diff --git a/README.md b/README.md
index 3776a1de..faf78964 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
> InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'.
-> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code'.
+> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code', and use 'pip install git+https://github.com/huggingface/transformers' to upgrade to the latest version.
> Stablelm needs to set the parameter '--trust_remote_code'.
diff --git a/lightllm/models/internlm_xcomposer/internlm_visual.py b/lightllm/models/internlm_xcomposer/internlm_visual.py
index 868aab0d..15133e2d 100644
--- a/lightllm/models/internlm_xcomposer/internlm_visual.py
+++ b/lightllm/models/internlm_xcomposer/internlm_visual.py
@@ -7,17 +7,19 @@
from typing import List, Union
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
+from rpyc.utils.classic import obtain
+from io import BytesIO
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
class InternVisionModel:
-
def __init__(self):
pass
def load_projector_update(self, config, weight_dir):
- projector_type = config.get("projector_type", 'mlp2x_gelu')
+ projector_type = config.get("projector_type", "mlp2x_gelu")
projector_weights = []
- mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
self.mlp_depth = int(mlp_gelu_match.group(1))
new_dict = {}
@@ -25,54 +27,55 @@ def load_projector_update(self, config, weight_dir):
if f.endswith(".bin"):
d = torch.load(os.path.join(weight_dir, f), "cpu")
for k, v in d.items():
- if 'vision_proj' in k:
+ if "vision_proj" in k:
projector_weights.append(v.half())
elif "vit.vision_tower." in k:
- new_dict[k[len("vit.vision_tower."):]] = v.half()
+ new_dict[k[len("vit.vision_tower.") :]] = v.half()
self.vision_tower.load_state_dict(new_dict, strict=True)
return projector_weights
- if projector_type == 'identity':
+ if projector_type == "identity":
return []
- raise ValueError(f'Unknown projector type: {projector_type}')
+ raise ValueError(f"Unknown projector type: {projector_type}")
def load_model(self, weight_dir):
config_file = os.path.join(weight_dir, "config.json")
config = json.load(open(config_file))
- self.select_layer = config.get('mm_vision_select_layer', -1)
- self.select_feature = config.get('mm_vision_select_feature', 'patch')
+ self.select_layer = config.get("mm_vision_select_layer", -1)
+ self.select_feature = config.get("mm_vision_select_feature", "patch")
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
- vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
+ vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336")
if isinstance(vision_path, list):
vision_path = vision_path[0]
if vision_path.startswith("./"):
vision_path = os.path.join(weight_dir, vision_path)
-
- self.image_processor = transforms.Compose([
- transforms.Resize((config["img_size"], config["img_size"]),
- interpolation=InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
- (0.26862954, 0.26130258, 0.27577711)),
- ])
+
+ self.image_processor = transforms.Compose(
+ [
+ transforms.Resize((config["img_size"], config["img_size"]), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ]
+ )
from transformers import CLIPVisionModel
+
self.vision_tower = CLIPVisionModel.from_pretrained(vision_path)
self.vision_tower.requires_grad_(False)
self.resize_pos(config, vision_path)
self.projector_weights = self.load_projector_update(config, weight_dir)
self.vision_tower = self.vision_tower.half()
- self.device = torch.device('cpu')
+ self.device = torch.device("cpu")
# load projector weights
assert len(self.projector_weights) == self.mlp_depth * 2
def resize_pos(self, config, vision_path):
- mm_vision_tower = vision_path.split('/')[-1]
- vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower)
+ mm_vision_tower = vision_path.split("/")[-1]
+ vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower)
patch_size = int(vision_tower_match.group(1))
clip_imge_size = int(vision_tower_match.group(2))
-
+
orig_size = clip_imge_size // patch_size
new_size = config["img_size"] // patch_size
if orig_size == new_size:
@@ -82,58 +85,51 @@ def resize_pos(self, config, vision_path):
pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight
pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0)
- if pos_embed_checkpoint.shape[1] == new_size**2 + 1:
+ if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1:
self.is_resize_pos = True
else:
embedding_size = pos_embed_checkpoint.shape[-1]
num_extra_tokens = 1
- new_num = new_size**2 + num_extra_tokens
- print('Position interpolate from %dx%d to %dx%d' %
- (orig_size, orig_size, new_size, new_size))
+ new_num = new_size ** 2 + num_extra_tokens
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
- embedding_size).permute(
- 0, 3, 1, 2)
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
- pos_tokens,
- size=(new_size, new_size),
- mode='bicubic',
- align_corners=False)
+ pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
+ )
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
new_pos_embed = new_pos_embed.squeeze(0)
- self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(
- new_num, 1024)
+ self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(
- new_pos_embed.to(pos_embed_checkpoint.dtype))
- self.vision_tower.vision_model.embeddings.position_ids = torch.arange(
- new_num).expand((1, -1))
+ new_pos_embed.to(pos_embed_checkpoint.dtype)
+ )
+ self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1))
self.is_resize_pos = True
def cuda(self):
self.vision_tower = self.vision_tower.cuda()
for i in range(len(self.projector_weights)):
self.projector_weights[i] = self.projector_weights[i].cuda()
- self.device = torch.device('cuda')
return self
# batch images infer
def forward(self, x):
- x = x.to(device=self.device, dtype=self.vision_tower.dtype)
+ x = x.cuda().to(dtype=self.vision_tower.dtype)
x = self.vision_tower(x, output_hidden_states=True)
x = x.hidden_states[self.select_layer]
- if self.select_feature == 'patch':
+ if self.select_feature == "patch":
x = x[:, 1:].contiguous()
-
+
if len(self.projector_weights) == 0:
return x
-
- B, L, N = x.shape
+
+ B, L, N = x.shape
x = x.view(-1, N)
# mm_project
x = F.linear(
@@ -151,16 +147,31 @@ def forward(self, x):
x = x.view(B, L, -1)
return x
- def encode(self, image_items: List[Union[str, Image.Image]]):
- images = []
- for item in image_items:
- if isinstance(item, Image.Image):
- image = item
- elif item.startswith("http://") or item.startswith("https://"):
- image = Image.open(requests.get(item, stream=True).raw)
+ def encode(self, image_uuids: List):
+ img_tensors = []
+ uuids = []
+ valid_id = 0
+ valid_ids = []
+
+ for i, item in enumerate(image_uuids):
+ item = obtain(item)
+ if isinstance(item, int):
+ uuids.append(item)
+ image_data = read_shm(get_shm_name_data(item))
+ image_data = Image.open(BytesIO(image_data)).convert("RGB")
+ t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
+ img_tensors.append(t)
else:
- image = Image.open(item)
- image = self.image_processor(image.convert('RGB')).unsqueeze(0).to(self.device)
- images.append(image)
- images = torch.cat(images, dim=0)
- return self.forward(images)
\ No newline at end of file
+ raise Exception("Unsupport input types: {} for {}".format(type(item), item))
+
+ cur_num = img_tensors[-1].shape[0]
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ img = torch.cat(img_tensors, dim=0)
+ all_img_embeds = self.forward(img)
+
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py
index b773141e..f0b31456 100644
--- a/lightllm/models/internvl/internvl_visual.py
+++ b/lightllm/models/internvl/internvl_visual.py
@@ -8,22 +8,21 @@
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
-import requests
-from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
-import rpyc
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from io import BytesIO
from lightllm.models.internvl.img_process import load_image
+from rpyc.utils.classic import obtain
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
class InternVLVisionModel:
- def __init__(self, kvargs):
- self.cache_port = kvargs["client_port"]
- self.cache_client = None
+ def __init__(self):
pass
def load_model(self, weight_dir):
assert torch.cuda.is_available()
- self.device = torch.device("cuda")
self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
self.config = json.load(open(os.path.join(weight_dir, "config.json")))
self.model = AutoModel.from_pretrained(
@@ -37,28 +36,32 @@ def load_model(self, weight_dir):
def cuda(self):
return self
- def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]):
+ def encode(self, image_uuids: List):
img_tensors = []
valid_ids = []
valid_id = 0
- # load images to batch tensor
- for i, url in enumerate(image_items):
- if isinstance(url, Image.Image):
- t = load_image(url, max_num=6)
+ uuids = []
+
+ for i, url in enumerate(image_uuids):
+ url = obtain(url)
+ if isinstance(url, int):
+ uuids.append(url)
+ image_data = read_shm(get_shm_name_data(url))
+ image_data = Image.open(BytesIO(image_data))
+ t = load_image(image_data)
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))
cur_num = img_tensors[-1].shape[0]
-
valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num
if len(img_tensors) <= 0:
return None
- # (b, 3, 224, 224)
imgs = torch.cat(img_tensors, dim=0)
- pixel_values = imgs.to(self.device, dtype=self.dtype)
+ pixel_values = imgs.cuda().to(dtype=self.dtype)
all_img_embeds = self.model.extract_feature(pixel_values)
- return [all_img_embeds[start:end] for start, end in valid_ids]
+
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py
index 776f2959..42deffa2 100644
--- a/lightllm/models/llava/llava_visual.py
+++ b/lightllm/models/llava/llava_visual.py
@@ -5,6 +5,13 @@
from PIL import Image
from typing import List, Union
from safetensors import safe_open
+from rpyc.utils.classic import obtain
+from io import BytesIO
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.utils.log_utils import init_logger
+
+
+logger = init_logger(__name__)
class LlavaVisionModel:
@@ -31,6 +38,7 @@ def load_model(self, weight_dir):
def load_hf_model(self, config, weight_dir):
from transformers import AutoConfig, AutoProcessor, LlavaForConditionalGeneration
+
config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True)
self.select_layer = config.vision_feature_layer
self.select_feature = config.vision_feature_select_strategy
@@ -48,12 +56,16 @@ def load_hf_model(self, config, weight_dir):
self.projector_weights = {}
for f in os.listdir(weight_dir):
if f.endswith(".safetensors"):
- d = safe_open(os.path.join(weight_dir, f), 'pt', 'cpu')
+ d = safe_open(os.path.join(weight_dir, f), "pt", "cpu")
for k in d.keys():
if "multi_modal_projector.linear_1" in k:
- self.projector_weights[k.replace("multi_modal_projector.linear_1", "model.mm_projector.0")] = d.get_tensor(k).half()
+ self.projector_weights[
+ k.replace("multi_modal_projector.linear_1", "model.mm_projector.0")
+ ] = d.get_tensor(k).half()
if "multi_modal_projector.linear_2" in k:
- self.projector_weights[k.replace("multi_modal_projector.linear_2", "model.mm_projector.2")] = d.get_tensor(k).half()
+ self.projector_weights[
+ k.replace("multi_modal_projector.linear_2", "model.mm_projector.2")
+ ] = d.get_tensor(k).half()
def load_bin_model(self, config, weight_dir):
self.select_layer = config.get("mm_vision_select_layer", -2)
@@ -68,6 +80,7 @@ def load_bin_model(self, config, weight_dir):
vision_path = os.path.join(weight_dir, vision_path)
from transformers import CLIPVisionModel, CLIPImageProcessor
+
self.image_processor = CLIPImageProcessor.from_pretrained(vision_path)
self.vision_tower = CLIPVisionModel.from_pretrained(vision_path).half()
@@ -84,13 +97,11 @@ def cuda(self):
self.vision_tower = self.vision_tower.cuda()
for k, v in self.projector_weights.items():
self.projector_weights[k] = v.cuda()
- self.device = torch.device("cuda")
return self
# batch images infer
def forward(self, x):
- x = x.half().to(device=self.device)
-
+ x = x.half().cuda()
x = self.vision_tower(x, output_hidden_states=True)
x = x.hidden_states[self.select_layer]
if self.select_feature == "patch" or self.select_feature == "default":
@@ -113,17 +124,31 @@ def forward(self, x):
x = x.view(B, L, -1)
return x
- def encode(self, image_items: List[Union[str, Image.Image]]):
- images = []
- for item in image_items:
- if isinstance(item, Image.Image):
- image = item
- elif item.startswith("http://") or item.startswith("https://"):
- import requests
- image = Image.open(requests.get(item, stream=True).raw)
+ def encode(self, image_uuids: List):
+ img_tensors = []
+ uuids = []
+ valid_id = 0
+ valid_ids = []
+
+ for i, item in enumerate(image_uuids):
+ item = obtain(item)
+ if isinstance(item, int):
+ uuids.append(item)
+ image_data = read_shm(get_shm_name_data(item))
+ image_data = Image.open(BytesIO(image_data)).convert("RGB")
+ t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
+ img_tensors.append(t)
else:
- image = Image.open(item)
- images.append(image.convert("RGB"))
+ raise Exception("Unsupport input types: {} for {}".format(type(item), item))
+
+ cur_num = img_tensors[-1].shape[0]
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ img = torch.cat(img_tensors, dim=0)
+ all_img_embeds = self.forward(img)
- images = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
- return self.forward(images)
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py
index eced443a..dbe4cb18 100644
--- a/lightllm/models/qwen2_vl/qwen2_visual.py
+++ b/lightllm/models/qwen2_vl/qwen2_visual.py
@@ -28,10 +28,9 @@
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
-import requests
-from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
-import rpyc
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from io import BytesIO
+from rpyc.utils.classic import obtain
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
@@ -112,7 +111,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
- )
+ ).cuda()
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
return hidden_states
@@ -382,11 +381,11 @@ def rot_pos_emb(self, grid_thw):
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.to(
dtype=self.get_dtype(),
- device=torch.device("cuda"),
+ device=self.device,
)
grid_thw = grid_thw.to(
dtype=torch.int32,
- device=torch.device("cuda"),
+ device=self.device,
)
hidden_states = self.patch_embed(hidden_states)
@@ -427,16 +426,22 @@ def load_model(self, weight_dir):
self.load_state_dict(weight_dict)
- def encode(self, image_items: List[Union[str, Image.Image]]):
+ def encode(self, image_uuids: List):
img_tensors = []
valid_ids = []
valid_id = 0
img_grids = []
- for i, url in enumerate(image_items):
- if isinstance(url, Image.Image):
- t = get_image(url)
- image_inputs = self.processor.preprocess(images=t, return_tensors="pt")
- pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
+ uuids = []
+
+ for i, url in enumerate(image_uuids):
+ url = obtain(url)
+ if isinstance(url, int):
+ uuids.append(url)
+ image_data = read_shm(get_shm_name_data(url))
+ image_data = Image.open(BytesIO(image_data))
+ image_data = get_image(image_data)
+ image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
+ pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
image_grid_thw = image_inputs["image_grid_thw"]
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
@@ -455,10 +460,10 @@ def encode(self, image_items: List[Union[str, Image.Image]]):
imgs = torch.cat(img_tensors, dim=0)
grid_thw = torch.cat(img_grids, dim=0)
- pixel_values = imgs.to(self.device, dtype=torch.float32)
- image_grid_thw = grid_thw.to(self.device)
+ pixel_values = imgs.cuda().to(dtype=torch.float32)
+ image_grid_thw = grid_thw.cuda()
pixel_values = pixel_values.type(self.get_dtype())
- all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw).to(self.device)
+ all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
- return [all_img_embeds[start:end] for start, end in valid_ids]
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
index 616703e3..6515ddc4 100644
--- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -38,6 +38,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
img_start_loc = 0
img_start_locs = []
+ device = layer_weight.wte_weight_.device
+ dtype = layer_weight.wte_weight_.dtype
+ hidden_size = layer_weight.wte_weight_.shape[1]
for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"]:
# skip the same image
@@ -50,10 +53,6 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
img_start_loc += img["token_num"]
-
- device = layer_weight.wte_weight_.device
- dtype = layer_weight.wte_weight_.dtype
- hidden_size = layer_weight.wte_weight_.shape[1]
out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
if len(img_weight) > 0:
img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py
index aa94d0b1..8eb91370 100644
--- a/lightllm/models/qwen_vl/model.py
+++ b/lightllm/models/qwen_vl/model.py
@@ -8,7 +8,6 @@
# Warp of the origal tokenizer
class QWenVLTokenizer:
-
def __init__(self, tokenizer, model_cfg):
self.tokenizer = tokenizer
# : 151857
@@ -18,7 +17,7 @@ def __init__(self, tokenizer, model_cfg):
self.image_end_tag = tokenizer.image_end_tag
self.image_end_id = tokenizer.img_end_id
# : 151859
- self.image_length = model_cfg['visual'].get("n_queries", 256)
+ self.image_length = model_cfg["visual"].get("n_queries", 256)
def _list_find(self, input_list, target, start_idx):
cur_list = input_list[start_idx:]
@@ -34,18 +33,18 @@ def _format_prompt(self, prompt):
parts = prompt.split(self.image_start_tag)
prompt = parts[0]
for idx, part in enumerate(parts[1:]):
- prompt += f'Picture {idx + 1}:' + self.image_start_tag + part
+ prompt += f"Picture {idx + 1}:" + self.image_start_tag + part
parts = prompt.split(self.image_end_tag)
prompt = parts[0]
for part in parts[1:]:
- prompt += self.image_end_tag + '\n' + part
+ prompt += self.image_end_tag + "\n" + part
return prompt
# only change the impl of the encode func:
def encode(self, prompt, multimodal_params: MultimodalParams = None):
prompt = unicodedata.normalize("NFC", prompt)
prompt = self._format_prompt(prompt)
- origin_ids = self.tokenizer.tokenizer.encode(prompt, allowed_special='all', disallowed_special=())
+ origin_ids = self.tokenizer.tokenizer.encode(prompt, allowed_special="all", disallowed_special=())
input_ids = []
image_id = 0
@@ -55,7 +54,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
start = self._list_find(origin_ids, self.image_start_id, end)
if start == -1:
break
- input_ids.extend(origin_ids[end: start])
+ input_ids.extend(origin_ids[end:start])
end = self._list_find(origin_ids, self.image_end_id, start)
if end == -1:
raise ValueError("Unclosed image token")
@@ -70,14 +69,14 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
end += 1
image_id += 1
- input_ids.extend(origin_ids[end: ])
+ input_ids.extend(origin_ids[end:])
if multimodal_params:
image_cnt = len(multimodal_params.images)
assert image_cnt == image_id, "invalid image tag num: {} vs {}!".format(image_cnt, image_id)
return input_ids
def __getattr__(self, name):
- if name != 'encode':
+ if name != "encode":
return getattr(self.tokenizer, name)
return self.encode
diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py
index fa372d41..5cd6b54f 100644
--- a/lightllm/models/qwen_vl/qwen_visual.py
+++ b/lightllm/models/qwen_vl/qwen_visual.py
@@ -11,13 +11,15 @@
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List, Union
import numpy as np
-
+import rpyc
+from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
+from rpyc.utils.classic import obtain
def get_abs_pos(abs_pos, tgt_size):
@@ -29,15 +31,21 @@ def get_abs_pos(abs_pos, tgt_size):
dtype = abs_pos.dtype
if src_size != tgt_size:
- return F.interpolate(
- abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
- size=(tgt_size, tgt_size),
- mode="bicubic",
- align_corners=False,
- ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
+ return (
+ F.interpolate(
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+ size=(tgt_size, tgt_size),
+ mode="bicubic",
+ align_corners=False,
+ )
+ .permute(0, 2, 3, 1)
+ .flatten(0, 2)
+ .to(dtype=dtype)
+ )
else:
return abs_pos
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
@@ -64,7 +72,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
@@ -76,14 +84,14 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega # (D/2,)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
@@ -96,14 +104,8 @@ class Resampler(nn.Module):
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
- def __init__(
- self,
- grid_size,
- embed_dim,
- num_heads,
- kv_dim=None,
- norm_layer=nn.LayerNorm
- ):
+
+ def __init__(self, grid_size, embed_dim, num_heads, kv_dim=None, norm_layer=nn.LayerNorm):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
@@ -114,7 +116,7 @@ def __init__(
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
- trunc_normal_(self.query, std=.02)
+ trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
@@ -124,12 +126,12 @@ def __init__(
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
-
+
# self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
+ trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
@@ -146,10 +148,8 @@ def forward(self, x, attn_mask=None):
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
- self._repeat(q, N) + self.pos_embed.unsqueeze(1),
- x + pos_embed.unsqueeze(1),
- x,
- attn_mask=attn_mask)[0]
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1), x + pos_embed.unsqueeze(1), x, attn_mask=attn_mask
+ )[0]
return out.permute(1, 0, 2)
def _repeat(self, query, N: int):
@@ -163,8 +163,7 @@ class VisualAttention(nn.Module):
and returns output of the same size.
"""
- def __init__(self, embed_dim, num_heads,
- bias=True, kdim=None, vdim=None):
+ def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None):
super(VisualAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
@@ -180,37 +179,37 @@ def __init__(self, embed_dim, num_heads,
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
- assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
+ assert self._qkv_same_embed_dim, "Only Support SelfAttention Currently"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
- def forward(self, query, key, value, attn_mask = None):
+ def forward(self, query, key, value, attn_mask=None):
# query/key/value: [sq, b, h]
sq, b, _ = query.size()
- assert torch.allclose(query, key), 'Only Support Self-Attention Currently'
+ assert torch.allclose(query, key), "Only Support Self-Attention Currently"
sk = sq
mixed_x_layer = self.in_proj(query)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
- new_tensor_shape = mixed_x_layer.size()[:-1] + \
- (self.num_attention_heads_per_partition,
- 3 * self.hidden_size_per_attention_head)
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
- query_layer, key_layer, value_layer = mixed_x_layer.split(
- self.hidden_size_per_attention_head, dim=-1)
+ query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
- query_layer = query_layer.view(sq,
- b * self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head).transpose(0, 1)
+ query_layer = query_layer.view(
+ sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ ).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
- key_layer = key_layer.view(sk,
- b * self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head).transpose(0, 1)
+ key_layer = key_layer.view(
+ sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ ).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
@@ -219,24 +218,23 @@ def forward(self, query, key, value, attn_mask = None):
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
- value_layer = value_layer.view(sk,
- b * self.num_attention_heads_per_partition,
- self.hidden_size_per_attention_head).transpose(0, 1)
+ value_layer = value_layer.view(
+ sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ ).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
- context_layer = context_layer.view(b,
- self.num_attention_heads_per_partition,
- sq, self.hidden_size_per_attention_head)
+ context_layer = context_layer.view(
+ b, self.num_attention_heads_per_partition, sq, self.hidden_size_per_attention_head
+ )
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
- new_context_layer_shape = context_layer.size()[:-2] + \
- (self.hidden_size_per_partition,)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
@@ -246,13 +244,13 @@ def forward(self, query, key, value, attn_mask = None):
class VisualAttentionBlock(nn.Module):
def __init__(
- self,
- d_model: int,
- n_head: int,
- mlp_ratio: float = 4.0,
- act_layer: Callable = nn.GELU,
- norm_layer: Callable = nn.LayerNorm,
- is_cross_attention: bool = False,
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = nn.LayerNorm,
+ is_cross_attention: bool = False,
):
super().__init__()
@@ -263,18 +261,22 @@ def __init__(
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head)
- self.mlp = nn.Sequential(OrderedDict([
- ("c_fc", nn.Linear(d_model, mlp_width)),
- ("gelu", act_layer()),
- ("c_proj", nn.Linear(mlp_width, d_model))
- ]))
+ self.mlp = nn.Sequential(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model)),
+ ]
+ )
+ )
def attention(
- self,
- q_x: torch.Tensor,
- k_x: Optional[torch.Tensor] = None,
- v_x: Optional[torch.Tensor] = None,
- attn_mask: Optional[torch.Tensor] = None,
+ self,
+ q_x: torch.Tensor,
+ k_x: Optional[torch.Tensor] = None,
+ v_x: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
@@ -283,11 +285,11 @@ def attention(
return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)
def forward(
- self,
- q_x: torch.Tensor,
- k_x: Optional[torch.Tensor] = None,
- v_x: Optional[torch.Tensor] = None,
- attn_mask: Optional[torch.Tensor] = None,
+ self,
+ q_x: torch.Tensor,
+ k_x: Optional[torch.Tensor] = None,
+ v_x: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
@@ -299,23 +301,24 @@ def forward(
class TransformerBlock(nn.Module):
def __init__(
- self,
- width: int,
- layers: int,
- heads: int,
- mlp_ratio: float = 4.0,
- act_layer: Callable = nn.GELU,
- norm_layer: Callable = nn.LayerNorm,
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = nn.LayerNorm,
):
super().__init__()
self.width = width
self.layers = layers
- self.resblocks = nn.ModuleList([
- VisualAttentionBlock(
- width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
- for _ in range(layers)
- ])
+ self.resblocks = nn.ModuleList(
+ [
+ VisualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
+ for _ in range(layers)
+ ]
+ )
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
@@ -330,18 +333,17 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
class QWenVisionTransformer(nn.Module):
-
def __init__(
- self,
- image_size: int,
- patch_size: int,
- width: int,
- layers: int,
- heads: int,
- mlp_ratio: float,
- n_queries: int = 256,
- output_dim: int = 512,
- **kwargs
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float,
+ n_queries: int = 256,
+ output_dim: int = 512,
+ **kwargs,
):
super().__init__()
image_height, image_width = self.image_size = (image_size, image_size)
@@ -351,14 +353,13 @@ def __init__(
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
- self.image_transform = transforms.Compose([
- transforms.Resize(
- (image_size, image_size),
- interpolation=InterpolationMode.BICUBIC
- ),
- transforms.ToTensor(),
- transforms.Normalize(mean=mean, std=std),
- ])
+ self.image_transform = transforms.Compose(
+ [
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=mean, std=std),
+ ]
+ )
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
@@ -387,7 +388,7 @@ def __init__(
norm_layer=norm_layer,
)
self.ln_post = norm_layer(output_dim)
- self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
+ self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim))
def forward(self, x: torch.Tensor):
x = x.to(
@@ -413,28 +414,42 @@ def forward(self, x: torch.Tensor):
x = x.to(dtype=torch.float16)
return x
-
- def encode(self, image_items: List[Union[str, Image.Image]]):
- images = []
- for item in image_items:
- if isinstance(item, Image.Image):
- image = item
- elif item.startswith("http://") or item.startswith("https://"):
- image = Image.open(requests.get(item, stream=True).raw)
+
+ def encode(self, image_uuids: List):
+ img_tensors = []
+ uuids = []
+ valid_id = 0
+ valid_ids = []
+
+ for i, item in enumerate(image_uuids):
+ item = obtain(item)
+ if isinstance(item, int):
+ uuids.append(item)
+ image_data = read_shm(get_shm_name_data(item))
+ image_data = Image.open(BytesIO(image_data)).convert("RGB")
+ t = self.image_transform(image_data)
+ img_tensors.append(t)
else:
- image = Image.open(item)
- image = image.convert("RGB")
- images.append(self.image_transform(image))
- images = torch.stack(images, dim=0)
- return self(images)
-
+ raise Exception("Unsupport input types: {} for {}".format(type(item), item))
+
+ valid_ids.append([valid_id, valid_id + 1])
+ valid_id += 1
+ if len(img_tensors) <= 0:
+ return None
+
+ pixel_values = torch.stack(img_tensors, dim=0)
+ all_img_embeds = self(pixel_values)
+
+ return all_img_embeds, uuids, valid_ids
+
def load_model(self, weight_dir):
import os
- weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin") ]
+
+ weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
weight_dict = {}
for file_ in weight_files:
f_weight_dict = torch.load(os.path.join(weight_dir, file_), "cpu")
for k, v in f_weight_dict.items():
if "visual" in k:
- weight_dict[k[len("transformer.visual."):]] = v
+ weight_dict[k[len("transformer.visual.") :]] = v
self.load_state_dict(weight_dict)
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index f3df77ab..b43e0f28 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -476,6 +476,21 @@ def make_argument_parser() -> argparse.ArgumentParser:
"--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value"
)
parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics")
+ parser.add_argument(
+ "--visual_infer_batch_size", type=int, default=4, help="number of images to process in each inference batch"
+ )
+ parser.add_argument(
+ "--visual_gpu_ids", nargs="+", type=int, default=[0], help="List of GPU IDs to use, e.g., 0 1 2"
+ )
+ parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT")
+ parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT")
+ parser.add_argument(
+ "--visual_nccl_ports",
+ nargs="+",
+ type=int,
+ default=[29500],
+ help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502",
+ )
parser.add_argument(
"--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
)
@@ -534,6 +549,24 @@ def main():
if args.beam_mode or args.diverse_mode:
assert args.router_token_ratio == 0.0
+ # 检查GPU数量是否足够
+ total_required_gpus = args.visual_dp * args.visual_tp
+ if len(args.visual_gpu_ids) < total_required_gpus:
+ raise ValueError(
+ f"Not enough GPUs specified. You need at least {total_required_gpus}, but got {len(args.visual_gpu_ids)}."
+ )
+ else:
+ args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus]
+
+ # 检查visual_nccl_port数量是否足够
+ if len(args.visual_nccl_ports) < args.visual_dp:
+ raise ValueError(
+ f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, "
+ f"but got ({len(args.visual_nccl_ports)})."
+ )
+ else:
+ args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+
if not args.splitfuse_mode:
# 普通模式下
if args.batch_max_tokens is None:
@@ -570,9 +603,19 @@ def main():
logger.info(f"all start args:{args}")
- can_use_ports = alloc_can_use_network_port(num=6 + args.tp, used_nccl_port=args.nccl_port)
+ already_uesd_ports = args.visual_nccl_ports + [args.nccl_port]
+ can_use_ports = alloc_can_use_network_port(
+ num=6 + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+ )
router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
- model_rpc_ports = can_use_ports[6:]
+ model_rpc_ports = can_use_ports[6 : 6 + args.tp]
+ can_use_ports = can_use_ports[6 + args.tp :]
+
+ visual_model_tp_ports = []
+ for _ in range(args.visual_dp):
+ tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
+ can_use_ports = can_use_ports[args.visual_tp :]
+ visual_model_tp_ports.append(tp_ports_for_dp)
if args.enable_multimodal:
start_submodule_processes(
@@ -586,7 +629,7 @@ def main():
start_visual_process,
],
start_args=[
- (args, router_port, visual_port, cache_port),
+ (args, router_port, visual_port, cache_port, visual_model_tp_ports),
],
)
@@ -617,7 +660,6 @@ def main():
(args, detokenization_port, httpserver_port),
],
)
-
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_clear
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index f3158dcb..228f7561 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -23,8 +23,8 @@ def __init__(
args,
router_port,
visual_port,
- client_port,
- infer_batch_size=4,
+ cache_port,
+ visual_model_rpc_ports,
):
context = zmq.asyncio.Context(2)
self.send_to_router = context.socket(zmq.PUSH)
@@ -32,33 +32,45 @@ def __init__(
self.recv_from_httpserver = context.socket(zmq.PULL)
self.recv_from_httpserver.bind(f"tcp://127.0.0.1:{visual_port}")
- self.cache_client = rpyc.connect("localhost", client_port)
- self.client_port = client_port
+ self.cache_client = rpyc.connect("localhost", cache_port)
+ self.cache_port = cache_port
self.waiting_reqs = []
self.model_weightdir = args.model_dir
self.tp_world_size = args.tp
- self.world_size = 1
- self.infer_batch_size = infer_batch_size
+ self.vit_dp = args.visual_dp
+ self.vit_tp = args.visual_tp
+ self.infer_batch_size = args.visual_infer_batch_size
self.trust_remote_code = args.trust_remote_code
self.args = args
+ self.visual_model_rpc_ports = visual_model_rpc_ports
async def wait_to_model_ready(self):
- self.model_rpcs: List[VisualModelRpcClient] = []
- for rank_id in range(self.world_size):
- rpc_model = await start_model_process(world_size=self.world_size)
- self.model_rpcs.append(rpc_model)
+ self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]
+
+ for dp_rank_id in range(self.vit_dp):
+ tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
+ for tp_rank_id in range(self.vit_tp):
+ rpc_model = await start_model_process(port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp)
+ self.model_rpcs[dp_rank_id].append(rpc_model)
init_model_ret = []
- for rank_id in range(self.world_size): # async init model process
- kvargs = {
- "weight_dir": self.model_weightdir,
- "trust_remote_code": self.trust_remote_code,
- "client_port": self.client_port,
- "rank_id": rank_id,
- "data_type": self.args.data_type,
- }
- init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))
+ for dp_rank_id in range(self.vit_dp): # async init model process
+ for tp_rank_id in range(self.vit_tp):
+ kvargs = {
+ "weight_dir": self.model_weightdir,
+ "trust_remote_code": self.trust_remote_code,
+ "vit_dp": self.vit_dp,
+ "vit_tp": self.vit_tp,
+ "cache_port": self.cache_port,
+ "tp_rank_id": tp_rank_id,
+ "dp_rank_id": dp_rank_id,
+ "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id,
+ "data_type": self.args.data_type,
+ "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id],
+ "visual_gpu_ids": self.args.visual_gpu_ids,
+ }
+ init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
await asyncio.gather(*init_model_ret)
return
@@ -73,25 +85,16 @@ async def infer_imgs(self, uuids):
if len(uuids) == 0:
return
# uuids -> PIL Images
- images = []
- for uid in uuids:
- image_data = read_shm(get_shm_name_data(uid))
- images.append(Image.open(BytesIO(image_data)))
- # print(" + got pil image:", images[-1].size, images[-1].mode)
- rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)]
- ans = await asyncio.gather(*rets)
- if self.world_size != 1:
- img_embed = obtain(ans[0])
- else:
- img_embed = ans[0]
- torch.cuda.synchronize()
- # b = time.time()
- for i in range(len(uuids)):
- # print(" + set_item_embed:", uuids[i], img_embed[i].shape)
- if not self.cache_client.root.get_item_embed(uuids[i]):
- cur_embed_bytes = tensor2bytes(img_embed[i])
- create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
- self.cache_client.root.set_item_embed(uuids[i])
+ tasks = []
+ for vit_dp_rank in range(self.vit_dp):
+ assigned_uuids = [uuids[i] for i in range(vit_dp_rank, len(uuids), self.vit_dp)]
+ if assigned_uuids:
+ for vit_tp_rank in range(self.vit_tp):
+ task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_uuids))
+ tasks.append(task)
+
+ # rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)]
+ await asyncio.gather(*tasks)
return
async def loop_for_fwd(self):
@@ -112,12 +115,10 @@ async def loop_for_fwd(self):
if len(uuids_need_infer) > 0:
reqs_need_infer.append(req)
else:
- # print(" + no need need infer, send to router...")
self.send_to_router.send_pyobj(req)
await self.infer_imgs(uuids_need_infer)
for req in reqs_need_infer:
- # print(" + after infer_imgs, send to router...")
self.send_to_router.send_pyobj(req)
async def loop_for_netio_req(self):
@@ -140,7 +141,7 @@ def clean_up(self):
return
-def start_visual_process(args, router_port, visual_port, client_port, pipe_writer):
+def start_visual_process(args, router_port, visual_port, cache_port, model_rpc_ports, pipe_writer):
# 注册graceful 退出的处理
from lightllm.utils.graceful_utils import graceful_registry
import inspect
@@ -148,7 +149,7 @@ def start_visual_process(args, router_port, visual_port, client_port, pipe_write
graceful_registry(inspect.currentframe().f_code.co_name)
try:
- visualserver = VisualManager(args, router_port, visual_port, client_port)
+ visualserver = VisualManager(args, router_port, visual_port, cache_port, model_rpc_ports)
asyncio.run(visualserver.wait_to_model_ready())
except Exception as e:
import traceback
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 55aebdef..be341d49 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -2,6 +2,7 @@
import numpy as np
import rpyc
import torch
+import os
from datetime import timedelta
from typing import Dict, List, Tuple
from transformers.configuration_utils import PretrainedConfig
@@ -12,29 +13,39 @@
from lightllm.models.internlm_xcomposer.internlm_visual import InternVisionModel
from lightllm.models.internvl.internvl_visual import InternVLVisionModel
from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel
+from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
class VisualModelRpcServer(rpyc.Service):
def exposed_init_model(self, kvargs):
- # 注册graceful 退出的处理
- from lightllm.utils.graceful_utils import graceful_registry
- import inspect
-
- graceful_registry(inspect.currentframe().f_code.co_name)
-
- # import torch
- # import torch.distributed as dist
- # world_size = kvargs["world_size"]
- # if world_size != 1:
- # kvargs = obtain(kvargs)
- # world_size = kvargs["world_size"]
- # dist.init_process_group('nccl', init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}',
- # rank=self.tp_rank, world_size=world_size)
- # torch.cuda.set_device(self.tp_rank)
+ import torch
+ import torch.distributed as dist
+
+ self.vit_dp = kvargs["vit_dp"]
+ self.vit_tp = kvargs["vit_tp"]
+ self.dp_rank_id = kvargs["dp_rank_id"]
+ self.tp_rank_id = kvargs["tp_rank_id"]
+ self.cache_port = kvargs["cache_port"]
weight_dir = kvargs["weight_dir"]
+ visual_gpu_ids = kvargs["visual_gpu_ids"]
+ visual_nccl_port = kvargs["visual_nccl_port"]
+ self.vit_rank_id = kvargs["vit_rank_id"]
+ self.cache_client = rpyc.connect("localhost", self.cache_port)
+
+ torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id])
+ if self.vit_tp != 1:
+ dist.init_process_group(
+ backend="nccl",
+ init_method=f"tcp://127.0.0.1:{visual_nccl_port}",
+ rank=self.tp_rank_id,
+ world_size=self.vit_tp,
+ )
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
+
+ if self.vit_tp != 1:
+ raise ValueError(f"ERROR: Not support vit_tp value: {self.vit_tp}")
try:
self.model_type = model_cfg["model_type"]
if self.model_type == "qwen":
@@ -46,14 +57,10 @@ def exposed_init_model(self, kvargs):
elif self.model_type == "internlmxcomposer2":
self.model = InternVisionModel()
elif self.model_type == "internvl_chat":
- # tp_rank = kvargs['rank_id']
- client_port = kvargs["client_port"]
- data_type = kvargs["data_type"]
- model_kvargs = {"weight_dir": weight_dir, "client_port": client_port, "data_type": data_type}
- self.model = InternVLVisionModel(model_kvargs)
-
+ self.model = InternVLVisionModel()
else:
raise Exception(f"can not support {self.model_type} now")
+
self.model.load_model(weight_dir)
self.model = self.model.cuda()
except Exception as e:
@@ -69,20 +76,33 @@ def exposed_init_model(self, kvargs):
# @calculate_time(show=True, min_cost_ms=150)
@torch.no_grad()
- def forward(self, images):
- return self.model.encode(images)
+ def forward(self, images_uuids):
+ return self.model.encode(images_uuids)
# @calculate_time(show=False, min_cost_ms=300)
- def exposed_encode(self, images):
- return self.forward(images)
+ def exposed_encode(self, images_uuids):
+ all_img_embeds, uuids, valid_ids = self.forward(images_uuids)
+ all_img_embeds = all_img_embeds.to(torch.device("cpu"))
+ if self.tp_rank_id == 0:
+ if len(uuids) == 0:
+ return [all_img_embeds[start:end] for start, end in valid_ids]
+ else:
+ for i in range(len(uuids)):
+ uid = uuids[i]
+ if not self.cache_client.root.get_item_embed(uid):
+ start, end = valid_ids[i]
+ cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
+ create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
+ self.cache_client.root.set_item_embed(uuids[i])
+ return
class VisualModelRpcClient:
- def __init__(self, model_rpc, world_size, rpc_server_process=None):
+ def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
self.model: VisualModelRpcServer = model_rpc
- self.world_size = world_size
+ self.vit_tp = vit_tp
self.rpc_server_process = rpc_server_process
- self.use_rpc = self.world_size != 1
+ self.use_rpc = True
if self.use_rpc:
def async_wrap(f):
@@ -111,14 +131,44 @@ async def init_model(self, kvargs):
else:
return
- async def encode(self, images):
- ans = self._encode(images)
+ async def encode(self, uuids):
+ ans = self._encode(uuids)
if self.use_rpc:
return await ans
else:
return ans
-async def start_model_process(world_size):
- if world_size == 1:
- return VisualModelRpcClient(VisualModelRpcServer(), world_size)
+def _init_env(port):
+ # 注册graceful 退出的处理
+ from lightllm.utils.graceful_utils import graceful_registry
+ import inspect
+
+ graceful_registry(inspect.currentframe().f_code.co_name)
+
+ from rpyc.utils.server import ThreadedServer
+
+ t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True})
+ t.start()
+ return
+
+
+async def start_model_process(port, vit_tp):
+ import multiprocessing
+
+ proc = multiprocessing.Process(target=_init_env, args=(port,))
+ proc.start()
+ await asyncio.sleep(2)
+ repeat_count = 0
+ while repeat_count < 20:
+ try:
+ con = rpyc.connect("localhost", port, config={"allow_pickle": True})
+ break
+ except BaseException:
+ await asyncio.sleep(1)
+ repeat_count += 1
+ if repeat_count == 20:
+ raise Exception("init rpc env error!")
+
+ assert proc.is_alive()
+ return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc)
diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py
index acac8fda..b58e0b09 100644
--- a/lightllm/utils/net_utils.py
+++ b/lightllm/utils/net_utils.py
@@ -1,12 +1,12 @@
import socket
-def alloc_can_use_network_port(num=3, used_nccl_port=None):
+def alloc_can_use_network_port(num=3, used_nccl_ports=None):
port_list = []
for port in range(10000, 65536):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
result = s.connect_ex(("localhost", port))
- if result != 0 and port != used_nccl_port:
+ if result != 0 and port not in used_nccl_ports:
port_list.append(port)
if len(port_list) == num: