From 8cb078b2ca0529c536c47398b7b64e543b6bf344 Mon Sep 17 00:00:00 2001 From: SangChengC Date: Fri, 9 Aug 2024 06:57:29 +0000 Subject: [PATCH 1/8] update_readme_internvl --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index b6e6b9fc..0e0441fd 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram - [Baichuan2-13b](https://github.com/baichuan-inc/Baichuan2) - [Baichuan-13b](https://github.com/baichuan-inc/Baichuan-13B) - [InternLM-7b](https://github.com/InternLM/InternLM) +- [InternVL-Chat](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5) - [Yi-34b](https://huggingface.co/01-ai/Yi-34B) - [Qwen-VL](https://huggingface.co/Qwen/Qwen-VL) - [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat) @@ -62,6 +63,10 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram > InternLM needs to set the parameter '--trust_remote_code'. +> InternVL-Chat(Phi3) needs to set the parameter '--eos_id 32007 --trust_remote_code'. + +> InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'. + > Stablelm needs to set the parameter '--trust_remote_code'. > Phi-3 only supports Mini and Small. From 4ab48ac497f690d058b9c6c9a77434760bb5893e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 23 Aug 2024 19:14:36 +0800 Subject: [PATCH 2/8] data parallelism for vision model --- .../internlm_xcomposer/internlm_visual.py | 93 ++++---- lightllm/models/internvl/internvl_visual.py | 12 +- lightllm/models/llava/llava_visual.py | 26 +- .../qwen_vl/layer_infer/pre_layer_infer.py | 9 +- lightllm/models/qwen_vl/qwen_visual.py | 224 +++++++++--------- lightllm/server/api_server.py | 7 +- lightllm/server/embed_cache/utils.py | 4 +- lightllm/server/visualserver/manager.py | 49 ++-- .../visualserver/model_infer/model_rpc.py | 81 +++++-- 9 files changed, 291 insertions(+), 214 deletions(-) diff --git a/lightllm/models/internlm_xcomposer/internlm_visual.py b/lightllm/models/internlm_xcomposer/internlm_visual.py index 868aab0d..aed4dc93 100644 --- a/lightllm/models/internlm_xcomposer/internlm_visual.py +++ b/lightllm/models/internlm_xcomposer/internlm_visual.py @@ -7,17 +7,19 @@ from typing import List, Union from torchvision import transforms from torchvision.transforms.functional import InterpolationMode +from rpyc.utils.classic import obtain class InternVisionModel: - - def __init__(self): + def __init__(self, kvargs): + self.tp_rank_ = kvargs["tp_rank"] + self.world_size_ = kvargs["world_size"] pass def load_projector_update(self, config, weight_dir): - projector_type = config.get("projector_type", 'mlp2x_gelu') + projector_type = config.get("projector_type", "mlp2x_gelu") projector_weights = [] - mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) if mlp_gelu_match: self.mlp_depth = int(mlp_gelu_match.group(1)) new_dict = {} @@ -25,54 +27,55 @@ def load_projector_update(self, config, weight_dir): if f.endswith(".bin"): d = torch.load(os.path.join(weight_dir, f), "cpu") for k, v in d.items(): - if 'vision_proj' in k: + if "vision_proj" in k: projector_weights.append(v.half()) elif "vit.vision_tower." 
in k: - new_dict[k[len("vit.vision_tower."):]] = v.half() + new_dict[k[len("vit.vision_tower.") :]] = v.half() self.vision_tower.load_state_dict(new_dict, strict=True) return projector_weights - if projector_type == 'identity': + if projector_type == "identity": return [] - raise ValueError(f'Unknown projector type: {projector_type}') + raise ValueError(f"Unknown projector type: {projector_type}") def load_model(self, weight_dir): config_file = os.path.join(weight_dir, "config.json") config = json.load(open(config_file)) - self.select_layer = config.get('mm_vision_select_layer', -1) - self.select_feature = config.get('mm_vision_select_feature', 'patch') + self.select_layer = config.get("mm_vision_select_layer", -1) + self.select_feature = config.get("mm_vision_select_feature", "patch") # load clip vision model by cfg['mm_vision_tower']: # huggingface_name or path_of_clip_relative_to_llava_model_dir - vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336') + vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336") if isinstance(vision_path, list): vision_path = vision_path[0] if vision_path.startswith("./"): vision_path = os.path.join(weight_dir, vision_path) - - self.image_processor = transforms.Compose([ - transforms.Resize((config["img_size"], config["img_size"]), - interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)), - ]) + + self.image_processor = transforms.Compose( + [ + transforms.Resize((config["img_size"], config["img_size"]), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) from transformers import CLIPVisionModel + self.vision_tower = CLIPVisionModel.from_pretrained(vision_path) self.vision_tower.requires_grad_(False) self.resize_pos(config, vision_path) self.projector_weights = self.load_projector_update(config, weight_dir) self.vision_tower = self.vision_tower.half() - self.device = torch.device('cpu') + self.device = torch.device("cpu") # load projector weights assert len(self.projector_weights) == self.mlp_depth * 2 def resize_pos(self, config, vision_path): - mm_vision_tower = vision_path.split('/')[-1] - vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower) + mm_vision_tower = vision_path.split("/")[-1] + vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower) patch_size = int(vision_tower_match.group(1)) clip_imge_size = int(vision_tower_match.group(2)) - + orig_size = clip_imge_size // patch_size new_size = config["img_size"] // patch_size if orig_size == new_size: @@ -82,43 +85,37 @@ def resize_pos(self, config, vision_path): pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0) - if pos_embed_checkpoint.shape[1] == new_size**2 + 1: + if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1: self.is_resize_pos = True else: embedding_size = pos_embed_checkpoint.shape[-1] num_extra_tokens = 1 - new_num = new_size**2 + num_extra_tokens - print('Position interpolate from %dx%d to %dx%d' % - (orig_size, orig_size, new_size, new_size)) + new_num = new_size ** 2 + num_extra_tokens + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the 
position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, - embedding_size).permute( - 0, 3, 1, 2) + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( - pos_tokens, - size=(new_size, new_size), - mode='bicubic', - align_corners=False) + pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False + ) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) new_pos_embed = new_pos_embed.squeeze(0) - self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding( - new_num, 1024) + self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024) self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter( - new_pos_embed.to(pos_embed_checkpoint.dtype)) - self.vision_tower.vision_model.embeddings.position_ids = torch.arange( - new_num).expand((1, -1)) + new_pos_embed.to(pos_embed_checkpoint.dtype) + ) + self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1)) self.is_resize_pos = True def cuda(self): self.vision_tower = self.vision_tower.cuda() for i in range(len(self.projector_weights)): self.projector_weights[i] = self.projector_weights[i].cuda() - self.device = torch.device('cuda') + self.device = torch.device(f"cuda:{self.tp_rank_}") return self # batch images infer @@ -127,13 +124,13 @@ def forward(self, x): x = self.vision_tower(x, output_hidden_states=True) x = x.hidden_states[self.select_layer] - if self.select_feature == 'patch': + if self.select_feature == "patch": x = x[:, 1:].contiguous() - + if len(self.projector_weights) == 0: return x - - B, L, N = x.shape + + B, L, N = x.shape x = x.view(-1, N) # mm_project x = F.linear( @@ -154,13 +151,17 @@ def forward(self, x): def encode(self, image_items: List[Union[str, Image.Image]]): images = [] for item in image_items: + if self.world_size_ != 1: + item = obtain(item) if isinstance(item, Image.Image): image = item elif item.startswith("http://") or item.startswith("https://"): + import requests + image = Image.open(requests.get(item, stream=True).raw) else: image = Image.open(item) - image = self.image_processor(image.convert('RGB')).unsqueeze(0).to(self.device) + image = self.image_processor(image.convert("RGB")).unsqueeze(0).to(self.device) images.append(image) images = torch.cat(images, dim=0) - return self.forward(images) \ No newline at end of file + return self.forward(images) diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index b773141e..1bc3c82a 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ b/lightllm/models/internvl/internvl_visual.py @@ -13,17 +13,21 @@ import rpyc from io import BytesIO from lightllm.models.internvl.img_process import load_image +from rpyc.utils.classic import obtain +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class InternVLVisionModel: def __init__(self, kvargs): - self.cache_port = kvargs["client_port"] - self.cache_client = None + self.tp_rank_ = kvargs["tp_rank"] + self.world_size_ = kvargs["world_size"] pass def load_model(self, weight_dir): assert torch.cuda.is_available() - self.device = torch.device("cuda") + self.device = torch.device(f"cuda:{self.tp_rank_}") self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() 
else torch.float32 self.config = json.load(open(os.path.join(weight_dir, "config.json"))) self.model = AutoModel.from_pretrained( @@ -43,6 +47,8 @@ def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]): valid_id = 0 # load images to batch tensor for i, url in enumerate(image_items): + if self.world_size_ != 1: + url = obtain(url) if isinstance(url, Image.Image): t = load_image(url, max_num=6) img_tensors.append(t) diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 776f2959..92e030e2 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -5,10 +5,17 @@ from PIL import Image from typing import List, Union from safetensors import safe_open +from rpyc.utils.classic import obtain +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) class LlavaVisionModel: - def __init__(self): + def __init__(self, kvargs): + self.tp_rank_ = kvargs["tp_rank"] + self.world_size_ = kvargs["world_size"] pass def load_model(self, weight_dir): @@ -31,6 +38,7 @@ def load_model(self, weight_dir): def load_hf_model(self, config, weight_dir): from transformers import AutoConfig, AutoProcessor, LlavaForConditionalGeneration + config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True) self.select_layer = config.vision_feature_layer self.select_feature = config.vision_feature_select_strategy @@ -48,12 +56,16 @@ def load_hf_model(self, config, weight_dir): self.projector_weights = {} for f in os.listdir(weight_dir): if f.endswith(".safetensors"): - d = safe_open(os.path.join(weight_dir, f), 'pt', 'cpu') + d = safe_open(os.path.join(weight_dir, f), "pt", "cpu") for k in d.keys(): if "multi_modal_projector.linear_1" in k: - self.projector_weights[k.replace("multi_modal_projector.linear_1", "model.mm_projector.0")] = d.get_tensor(k).half() + self.projector_weights[ + k.replace("multi_modal_projector.linear_1", "model.mm_projector.0") + ] = d.get_tensor(k).half() if "multi_modal_projector.linear_2" in k: - self.projector_weights[k.replace("multi_modal_projector.linear_2", "model.mm_projector.2")] = d.get_tensor(k).half() + self.projector_weights[ + k.replace("multi_modal_projector.linear_2", "model.mm_projector.2") + ] = d.get_tensor(k).half() def load_bin_model(self, config, weight_dir): self.select_layer = config.get("mm_vision_select_layer", -2) @@ -68,6 +80,7 @@ def load_bin_model(self, config, weight_dir): vision_path = os.path.join(weight_dir, vision_path) from transformers import CLIPVisionModel, CLIPImageProcessor + self.image_processor = CLIPImageProcessor.from_pretrained(vision_path) self.vision_tower = CLIPVisionModel.from_pretrained(vision_path).half() @@ -84,7 +97,7 @@ def cuda(self): self.vision_tower = self.vision_tower.cuda() for k, v in self.projector_weights.items(): self.projector_weights[k] = v.cuda() - self.device = torch.device("cuda") + self.device = torch.device(f"cuda:{self.tp_rank_}") return self # batch images infer @@ -116,10 +129,13 @@ def forward(self, x): def encode(self, image_items: List[Union[str, Image.Image]]): images = [] for item in image_items: + if self.world_size_ != 1: + item = obtain(item) if isinstance(item, Image.Image): image = item elif item.startswith("http://") or item.startswith("https://"): import requests + image = Image.open(requests.get(item, stream=True).raw) else: image = Image.open(item) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 
616703e3..3a099b12 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -38,6 +38,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_loc = 0 img_start_locs = [] + device = layer_weight.wte_weight_.device + dtype = layer_weight.wte_weight_.dtype + hidden_size = layer_weight.wte_weight_.shape[1] for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"]: # skip the same image @@ -45,15 +48,11 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei continue # pull the img_embeds by uid from shm data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).reshape(img["token_num"], -1)) + img_weight.append(bytes2tensor(data, device).reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) img_start_loc += img["token_num"] - - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype - hidden_size = layer_weight.wte_weight_.shape[1] out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) if len(img_weight) > 0: img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index fa372d41..a5529e66 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -18,6 +18,7 @@ from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode +from rpyc.utils.classic import obtain def get_abs_pos(abs_pos, tgt_size): @@ -29,15 +30,21 @@ def get_abs_pos(abs_pos, tgt_size): dtype = abs_pos.dtype if src_size != tgt_size: - return F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size, tgt_size), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + return ( + F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .flatten(0, 2) + .to(dtype=dtype) + ) else: return abs_pos + # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ @@ -64,7 +71,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb @@ -76,14 +83,14 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. 
/ 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb @@ -96,14 +103,8 @@ class Resampler(nn.Module): Outputs: A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__( - self, - grid_size, - embed_dim, - num_heads, - kv_dim=None, - norm_layer=nn.LayerNorm - ): + + def __init__(self, grid_size, embed_dim, num_heads, kv_dim=None, norm_layer=nn.LayerNorm): super().__init__() self.num_queries = grid_size ** 2 self.embed_dim = embed_dim @@ -114,7 +115,7 @@ def __init__( ).requires_grad_(False) self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) - trunc_normal_(self.query, std=.02) + trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != embed_dim: self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) @@ -124,12 +125,12 @@ def __init__( self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - + # self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -146,10 +147,8 @@ def forward(self, x, attn_mask=None): N = x.shape[1] q = self.ln_q(self.query) out = self.attn( - self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] + self._repeat(q, N) + self.pos_embed.unsqueeze(1), x + pos_embed.unsqueeze(1), x, attn_mask=attn_mask + )[0] return out.permute(1, 0, 2) def _repeat(self, query, N: int): @@ -163,8 +162,7 @@ class VisualAttention(nn.Module): and returns output of the same size. """ - def __init__(self, embed_dim, num_heads, - bias=True, kdim=None, vdim=None): + def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): super(VisualAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim @@ -180,37 +178,37 @@ def __init__(self, embed_dim, num_heads, self.hidden_size_per_partition = embed_dim # Strided linear layer. 
- assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' + assert self._qkv_same_embed_dim, "Only Support SelfAttention Currently" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - def forward(self, query, key, value, attn_mask = None): + def forward(self, query, key, value, attn_mask=None): # query/key/value: [sq, b, h] sq, b, _ = query.size() - assert torch.allclose(query, key), 'Only Support Self-Attention Currently' + assert torch.allclose(query, key), "Only Support Self-Attention Currently" sk = sq mixed_x_layer = self.in_proj(query) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - query_layer, key_layer, value_layer = mixed_x_layer.split( - self.hidden_size_per_attention_head, dim=-1) + query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, dim=-1) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(sq, - b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(0, 1) # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(sk, - b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + key_layer = key_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(0, 1) q_scaled = query_layer / self.norm_factor if attn_mask is not None: @@ -219,24 +217,23 @@ def forward(self, query, key, value, attn_mask = None): attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) attention_probs = attention_probs.softmax(dim=-1) - value_layer = value_layer.view(sk, - b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + value_layer = value_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(0, 1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer) # change view [b, np, sq, hn] - context_layer = context_layer.view(b, - self.num_attention_heads_per_partition, - sq, self.hidden_size_per_attention_head) + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, self.hidden_size_per_attention_head + ) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) output = self.out_proj(context_layer) @@ -246,13 +243,13 @@ def forward(self, query, key, value, attn_mask = None): class VisualAttentionBlock(nn.Module): def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, - is_cross_attention: 
bool = False, + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + is_cross_attention: bool = False, ): super().__init__() @@ -263,18 +260,22 @@ def __init__( self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.attn = VisualAttention(d_model, n_head) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) def attention( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x @@ -283,11 +284,11 @@ def attention( return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None @@ -299,23 +300,24 @@ def forward( class TransformerBlock(nn.Module): def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, ): super().__init__() self.width = width self.layers = layers - self.resblocks = nn.ModuleList([ - VisualAttentionBlock( - width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - ]) + self.resblocks = nn.ModuleList( + [ + VisualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ] + ) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -330,19 +332,21 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class QWenVisionTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - n_queries: int = 256, - output_dim: int = 512, - **kwargs + self, + kvargs, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + **kwargs ): + self.tp_rank_ = kvargs["tp_rank"] + self.world_size_ = kvargs["world_size"] super().__init__() image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) @@ -351,14 +355,13 @@ def __init__( mean = (0.48145466, 0.4578275, 0.40821073) std = (0.26862954, 0.26130258, 0.27577711) - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC - ), - transforms.ToTensor(), - 
transforms.Normalize(mean=mean, std=std), - ]) + self.image_transform = transforms.Compose( + [ + transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) @@ -387,7 +390,7 @@ def __init__( norm_layer=norm_layer, ) self.ln_post = norm_layer(output_dim) - self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) + self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim)) def forward(self, x: torch.Tensor): x = x.to( @@ -413,10 +416,12 @@ def forward(self, x: torch.Tensor): x = x.to(dtype=torch.float16) return x - + def encode(self, image_items: List[Union[str, Image.Image]]): images = [] for item in image_items: + if self.world_size_ != 1: + item = obtain(item) if isinstance(item, Image.Image): image = item elif item.startswith("http://") or item.startswith("https://"): @@ -427,14 +432,15 @@ def encode(self, image_items: List[Union[str, Image.Image]]): images.append(self.image_transform(image)) images = torch.stack(images, dim=0) return self(images) - + def load_model(self, weight_dir): import os - weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin") ] + + weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] weight_dict = {} for file_ in weight_files: f_weight_dict = torch.load(os.path.join(weight_dir, file_), "cpu") for k, v in f_weight_dict.items(): if "visual" in k: - weight_dict[k[len("transformer.visual."):]] = v + weight_dict[k[len("transformer.visual.") :]] = v self.load_state_dict(weight_dict) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 9d414251..ce7884a4 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -504,9 +504,10 @@ def main(): logger.info(f"all start args:{args}") - can_use_ports = alloc_can_use_network_port(num=6 + args.tp, used_nccl_port=args.nccl_port) + can_use_ports = alloc_can_use_network_port(num=6 + args.tp * 2, used_nccl_port=args.nccl_port) router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6] - model_rpc_ports = can_use_ports[6:] + model_rpc_ports = can_use_ports[6 : 6 + args.tp] + visual_model_rpc_ports = can_use_ports[6 + args.tp :] if args.enable_multimodal: start_submodule_processes( @@ -555,7 +556,7 @@ def main(): start_visual_process, ], start_args=[ - (args, router_port, visual_port, cache_port), + (args, router_port, visual_port, cache_port, visual_model_rpc_ports), ], ) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 78bce1ea..20c78724 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -13,9 +13,9 @@ def tensor2bytes(t): return buf.read() -def bytes2tensor(b): +def bytes2tensor(b, device): # return torch.from_numpy(np.frombuffer(b, dtype=np.float16)).cuda() - return torch.load(BytesIO(b)) + return torch.load(BytesIO(b), map_location=device) def create_shm(name, data): diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index f3158dcb..d6e2d5a9 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -24,6 +24,7 @@ def __init__( router_port, visual_port, client_port, + model_rpc_ports, infer_batch_size=4, ): context = 
zmq.asyncio.Context(2) @@ -37,16 +38,17 @@ def __init__( self.waiting_reqs = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp - self.world_size = 1 + self.world_size = args.tp self.infer_batch_size = infer_batch_size self.trust_remote_code = args.trust_remote_code self.args = args + self.model_rpcs_ports = model_rpc_ports async def wait_to_model_ready(self): self.model_rpcs: List[VisualModelRpcClient] = [] for rank_id in range(self.world_size): - rpc_model = await start_model_process(world_size=self.world_size) + rpc_model = await start_model_process(port=self.model_rpcs_ports[rank_id], world_size=self.world_size) self.model_rpcs.append(rpc_model) init_model_ret = [] @@ -54,9 +56,11 @@ async def wait_to_model_ready(self): kvargs = { "weight_dir": self.model_weightdir, "trust_remote_code": self.trust_remote_code, + "world_size": self.world_size, "client_port": self.client_port, "rank_id": rank_id, "data_type": self.args.data_type, + "nccl_port": self.args.nccl_port, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) @@ -78,22 +82,33 @@ async def infer_imgs(self, uuids): image_data = read_shm(get_shm_name_data(uid)) images.append(Image.open(BytesIO(image_data))) # print(" + got pil image:", images[-1].size, images[-1].mode) - rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)] - ans = await asyncio.gather(*rets) - if self.world_size != 1: - img_embed = obtain(ans[0]) - else: - img_embed = ans[0] + tasks = [] + for tp_rank in range(self.world_size): + assigned_images = [images[i] for i in range(tp_rank, len(images), self.world_size)] + assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.world_size)] + if assigned_images: + task = asyncio.create_task(self.encode_and_store(tp_rank, assigned_images, assigned_uuids)) + tasks.append(task) + + # rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)] + await asyncio.gather(*tasks) torch.cuda.synchronize() - # b = time.time() - for i in range(len(uuids)): - # print(" + set_item_embed:", uuids[i], img_embed[i].shape) - if not self.cache_client.root.get_item_embed(uuids[i]): - cur_embed_bytes = tensor2bytes(img_embed[i]) - create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) - self.cache_client.root.set_item_embed(uuids[i]) return + async def encode_and_store(self, tp_rank, assigned_images, assigned_uuids): + # aynsc vit-encode + img_embeds = [] + result = await self.model_rpcs[tp_rank].encode(assigned_images) + img_embeds.extend(obtain(result)) + print(f"cuda{tp_rank} is work on uuid {assigned_uuids}") + print("len of img_embeds is", len(img_embeds)) + # write img_embed to shm + for i, uid in enumerate(assigned_uuids): + if not self.cache_client.root.get_item_embed(uid): + cur_embed_bytes = tensor2bytes(img_embeds[i]) + create_shm(get_shm_name_embed(uid), cur_embed_bytes) + self.cache_client.root.set_item_embed(uid) + async def loop_for_fwd(self): while True: if len(self.waiting_reqs) == 0: @@ -140,7 +155,7 @@ def clean_up(self): return -def start_visual_process(args, router_port, visual_port, client_port, pipe_writer): +def start_visual_process(args, router_port, visual_port, client_port, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 from lightllm.utils.graceful_utils import graceful_registry import inspect @@ -148,7 +163,7 @@ def start_visual_process(args, router_port, visual_port, client_port, pipe_write graceful_registry(inspect.currentframe().f_code.co_name) try: - visualserver = 
VisualManager(args, router_port, visual_port, client_port) + visualserver = VisualManager(args, router_port, visual_port, client_port, model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: import traceback diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 264c83c3..ed271028 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -2,6 +2,7 @@ import numpy as np import rpyc import torch +import os from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -17,40 +18,40 @@ class VisualModelRpcServer(rpyc.Service): def exposed_init_model(self, kvargs): - # 注册graceful 退出的处理 - from lightllm.utils.graceful_utils import graceful_registry - import inspect - - graceful_registry(inspect.currentframe().f_code.co_name) - - # import torch - # import torch.distributed as dist - # world_size = kvargs["world_size"] - # if world_size != 1: - # kvargs = obtain(kvargs) - # world_size = kvargs["world_size"] - # dist.init_process_group('nccl', init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}', - # rank=self.tp_rank, world_size=world_size) - # torch.cuda.set_device(self.tp_rank) + import torch + import torch.distributed as dist + + world_size = kvargs["world_size"] + self.tp_rank = kvargs["rank_id"] + client_port = kvargs["client_port"] + data_type = kvargs["data_type"] weight_dir = kvargs["weight_dir"] + model_kvargs = { + "tp_rank": self.tp_rank, + "world_size": world_size, + "weight_dir": weight_dir, + "client_port": client_port, + "data_type": data_type, + } + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(kvargs["nccl_port"] + 1) + dist.init_process_group(backend="nccl", rank=self.tp_rank, world_size=world_size) + torch.cuda.set_device(self.tp_rank) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + try: self.model_type = model_cfg["model_type"] if self.model_type == "qwen": - self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() + self.model = QWenVisionTransformer(model_kvargs, **model_cfg["visual"]).eval().bfloat16() elif self.model_type == "llava": - self.model = LlavaVisionModel() + self.model = LlavaVisionModel(model_kvargs) elif self.model_type == "internlmxcomposer2": - self.model = InternVisionModel() + self.model = InternVisionModel(model_kvargs) elif self.model_type == "internvl_chat": - # tp_rank = kvargs['rank_id'] - client_port = kvargs["client_port"] - data_type = kvargs["data_type"] - model_kvargs = {"weight_dir": weight_dir, "client_port": client_port, "data_type": data_type} self.model = InternVLVisionModel(model_kvargs) - else: raise Exception(f"can not support {self.model_type} now") + self.model.load_model(weight_dir) self.model = self.model.cuda() except Exception as e: @@ -116,6 +117,38 @@ async def encode(self, images): return ans -async def start_model_process(world_size): +def _init_env(port): + # 注册graceful 退出的处理 + from lightllm.utils.graceful_utils import graceful_registry + import inspect + + graceful_registry(inspect.currentframe().f_code.co_name) + + from rpyc.utils.server import ThreadedServer + + t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) + t.start() + return + + +async def start_model_process(port, world_size): if world_size == 1: return VisualModelRpcClient(VisualModelRpcServer(), world_size) + import 
multiprocessing + + proc = multiprocessing.Process(target=_init_env, args=(port,)) + proc.start() + await asyncio.sleep(2) + repeat_count = 0 + while repeat_count < 20: + try: + con = rpyc.connect("localhost", port, config={"allow_pickle": True}) + break + except BaseException: + await asyncio.sleep(1) + repeat_count += 1 + if repeat_count == 20: + raise Exception("init rpc env error!") + + assert proc.is_alive() + return VisualModelRpcClient(con.root, world_size, rpc_server_process=proc) From 9c03c3d101b7248232cfb47cb6b2db8d548e37e0 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 27 Sep 2024 11:23:27 +0000 Subject: [PATCH 3/8] visual dp --- .../internlm_xcomposer/internlm_visual.py | 60 +++++++++++++++---- lightllm/models/internvl/internvl_visual.py | 30 ++++++++-- lightllm/models/llava/llava_visual.py | 57 +++++++++++++++--- lightllm/models/qwen2_vl/qwen2_visual.py | 44 +++++++++++++- lightllm/models/qwen_vl/qwen_visual.py | 56 +++++++++++++---- lightllm/server/api_server.py | 11 +++- lightllm/server/visualserver/manager.py | 47 ++++++--------- .../visualserver/model_infer/model_rpc.py | 37 +++++++----- 8 files changed, 259 insertions(+), 83 deletions(-) diff --git a/lightllm/models/internlm_xcomposer/internlm_visual.py b/lightllm/models/internlm_xcomposer/internlm_visual.py index aed4dc93..dec1bb62 100644 --- a/lightllm/models/internlm_xcomposer/internlm_visual.py +++ b/lightllm/models/internlm_xcomposer/internlm_visual.py @@ -8,12 +8,18 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from rpyc.utils.classic import obtain +from io import BytesIO +import rpyc +from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed +from lightllm.utils.log_utils import init_logger class InternVisionModel: def __init__(self, kvargs): self.tp_rank_ = kvargs["tp_rank"] - self.world_size_ = kvargs["world_size"] + self.world_size_ = kvargs["vit_world_size"] + self.client_port = kvargs["client_port"] + self.cache_client = rpyc.connect("localhost", self.client_port) pass def load_projector_update(self, config, weight_dir): @@ -148,20 +154,54 @@ def forward(self, x): x = x.view(B, L, -1) return x - def encode(self, image_items: List[Union[str, Image.Image]]): - images = [] - for item in image_items: + def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]): + img_tensors = [] + uuids = [] + valid_id = 0 + valid_ids = [] + for i, item in enumerate(image_items): if self.world_size_ != 1: item = obtain(item) if isinstance(item, Image.Image): - image = item + image = item.convert("RGB") + t = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"] + img_tensors.append(t) + elif isinstance(item, torch.Tensor): + img_tensors.append(item) + elif isinstance(item, int): + uuids.append(item) + image_data = read_shm(get_shm_name_data(item)) + image_data = Image.open(BytesIO(image_data)).convert("RGB") + t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] + img_tensors.append(t) elif item.startswith("http://") or item.startswith("https://"): import requests image = Image.open(requests.get(item, stream=True).raw) else: - image = Image.open(item) - image = self.image_processor(image.convert("RGB")).unsqueeze(0).to(self.device) - images.append(image) - images = torch.cat(images, dim=0) - return self.forward(images) + raise Exception("Unsupport input types: {} for {}".format(type(item), item)) + + cur_num = 
img_tensors[-1].shape[0] + + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + img = torch.cat(img_tensors, dim=0) + pixel_values = img.to(self.device) + all_img_embeds = self.forward(pixel_values) + + if len(uuids) == 0: + return [all_img_embeds[start:end] for start, end in valid_ids] + else: + for i in range(len(uuids)): + uid = uuids[i] + if not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + + return diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index 1bc3c82a..27768c60 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ b/lightllm/models/internvl/internvl_visual.py @@ -22,7 +22,9 @@ class InternVLVisionModel: def __init__(self, kvargs): self.tp_rank_ = kvargs["tp_rank"] - self.world_size_ = kvargs["world_size"] + self.world_size_ = kvargs["vit_world_size"] + self.client_port = kvargs["client_port"] + self.cache_client = rpyc.connect("localhost", self.client_port) pass def load_model(self, weight_dir): @@ -41,17 +43,27 @@ def load_model(self, weight_dir): def cuda(self): return self - def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]): + def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]): img_tensors = [] valid_ids = [] valid_id = 0 + uuids = [] # load images to batch tensor + for i, url in enumerate(image_items): if self.world_size_ != 1: url = obtain(url) if isinstance(url, Image.Image): t = load_image(url, max_num=6) img_tensors.append(t) + elif isinstance(url, torch.Tensor): + img_tensors.append(url) + elif isinstance(url, int): + uuids.append(url) + image_data = read_shm(get_shm_name_data(url)) + image_data = Image.open(BytesIO(image_data)) + t = load_image(image_data) + img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(url), url)) @@ -62,9 +74,19 @@ def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]): if len(img_tensors) <= 0: return None - # (b, 3, 224, 224) imgs = torch.cat(img_tensors, dim=0) pixel_values = imgs.to(self.device, dtype=self.dtype) all_img_embeds = self.model.extract_feature(pixel_values) - return [all_img_embeds[start:end] for start, end in valid_ids] + + if len(uuids) == 0: + return [all_img_embeds[start:end] for start, end in valid_ids] + else: + for i in range(len(uuids)): + uid = uuids[i] + if not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + return diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 92e030e2..34cac039 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -6,6 +6,9 @@ from typing import List, Union from safetensors import safe_open from rpyc.utils.classic import obtain +from io import BytesIO +import rpyc +from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.log_utils import init_logger @@ -15,7 +18,9 @@ class LlavaVisionModel: def __init__(self, kvargs): self.tp_rank_ = kvargs["tp_rank"] - self.world_size_ = kvargs["world_size"] + 
self.world_size_ = kvargs["vit_world_size"] + self.client_port = kvargs["client_port"] + self.cache_client = rpyc.connect("localhost", self.client_port) pass def load_model(self, weight_dir): @@ -126,20 +131,54 @@ def forward(self, x): x = x.view(B, L, -1) return x - def encode(self, image_items: List[Union[str, Image.Image]]): - images = [] - for item in image_items: + def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]): + img_tensors = [] + uuids = [] + valid_id = 0 + valid_ids = [] + for i, item in enumerate(image_items): if self.world_size_ != 1: item = obtain(item) if isinstance(item, Image.Image): - image = item + image = item.convert("RGB") + t = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"] + img_tensors.append(t) + elif isinstance(item, torch.Tensor): + img_tensors.append(item) + elif isinstance(item, int): + uuids.append(item) + image_data = read_shm(get_shm_name_data(item)) + image_data = Image.open(BytesIO(image_data)).convert("RGB") + t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] + img_tensors.append(t) elif item.startswith("http://") or item.startswith("https://"): import requests image = Image.open(requests.get(item, stream=True).raw) else: - image = Image.open(item) - images.append(image.convert("RGB")) + raise Exception("Unsupport input types: {} for {}".format(type(item), item)) - images = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"] - return self.forward(images) + cur_num = img_tensors[-1].shape[0] + + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + img = torch.cat(img_tensors, dim=0) + pixel_values = img.to(self.device) + all_img_embeds = self.forward(pixel_values) + + if len(uuids) == 0: + return [all_img_embeds[start:end] for start, end in valid_ids] + else: + for i in range(len(uuids)): + uid = uuids[i] + if not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + + return diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index eced443a..d5ced3b5 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -32,6 +32,7 @@ from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed import rpyc from io import BytesIO +from rpyc.utils.classic import obtain from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging from transformers.modeling_utils import PreTrainedModel @@ -295,6 +296,7 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: class Qwen2VisionTransformerPretrainedModel(nn.Module): def __init__( self, + kvargs, depth=32, embed_dim=1280, hidden_size=3584, @@ -307,6 +309,10 @@ def __init__( temporal_patch_size=2, **kwargs, ): + self.tp_tank_ = kvargs["tp_rank"] + self.world_size_ = kvargs["vit_world_size"] + self.client_port = kvargs["client_port"] + self.cache_client = rpyc.connect("localhost", self.client_port) super().__init__() self.depth = depth self.embed_dim = embed_dim @@ -427,12 +433,15 @@ def load_model(self, weight_dir): self.load_state_dict(weight_dict) - def encode(self, image_items: List[Union[str, Image.Image]]): + def encode(self, image_items: 
List[Union[int, str, torch.Tensor, Image.Image]]): img_tensors = [] valid_ids = [] valid_id = 0 img_grids = [] + uuids = [] for i, url in enumerate(image_items): + if self.world_size_ != 1: + url = obtain(url) if isinstance(url, Image.Image): t = get_image(url) image_inputs = self.processor.preprocess(images=t, return_tensors="pt") @@ -440,6 +449,26 @@ def encode(self, image_items: List[Union[str, Image.Image]]): image_grid_thw = image_inputs["image_grid_thw"] img_tensors.append(pixel_values) img_grids.append(image_grid_thw) + elif isinstance(url, torch.Tensor): + img_tensors.append(url) + elif isinstance(url, int): + uuids.append(url) + image_data = read_shm(get_shm_name_data(url)) + image_data = Image.open(BytesIO(image_data)) + image_data = get_image(image_data) + image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") + pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device="cuda") + image_grid_thw = image_inputs["image_grid_thw"] + img_tensors.append(pixel_values) + img_grids.append(image_grid_thw) + elif url.startswith("http://") or url.startswith("https://"): + image_data = Image.open(requests.get(url, stream=True).raw) + image_data = get_image(image_data) + image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") + pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device="cuda") + image_grid_thw = image_inputs["image_grid_thw"] + img_tensors.append(pixel_values) + img_grids.append(image_grid_thw) else: raise Exception("Unsupport input types: {} for {}".format(type(url), url)) @@ -461,4 +490,15 @@ def encode(self, image_items: List[Union[str, Image.Image]]): pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw).to(self.device) - return [all_img_embeds[start:end] for start, end in valid_ids] + if len(uuids) == 0: + return [all_img_embeds[start:end] for start, end in valid_ids] + else: + for i in range(len(uuids)): + uid = uuids[i] + if not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + + return diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index a5529e66..d0b75428 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -11,7 +11,8 @@ from PIL import Image from typing import Callable, Optional, Sequence, Tuple, List, Union import numpy as np - +import rpyc +from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed import torch from torch import nn from torch.nn import functional as F @@ -346,7 +347,9 @@ def __init__( **kwargs ): self.tp_rank_ = kvargs["tp_rank"] - self.world_size_ = kvargs["world_size"] + self.world_size_ = kvargs["vit_world_size"] + self.client_port = kvargs["client_port"] + self.cache_client = rpyc.connect("localhost", self.client_port) super().__init__() image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) @@ -417,21 +420,52 @@ def forward(self, x: torch.Tensor): return x - def encode(self, image_items: List[Union[str, Image.Image]]): - images = [] - for item in image_items: + def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]): + img_tensors = [] + uuids 
= [] + valid_id = 0 + valid_ids = [] + for i, item in enumerate(image_items): if self.world_size_ != 1: item = obtain(item) if isinstance(item, Image.Image): - image = item + image = item.convert("RGB") + t = self.image_transform(image) + img_tensors.append(t) + elif isinstance(item, torch.Tensor): + img_tensors.append(item) + elif isinstance(item, int): + uuids.append(item) + image_data = read_shm(get_shm_name_data(item)) + image_data = Image.open(BytesIO(image_data)).convert("RGB") + t = self.image_transform(image_data) + img_tensors.append(t) elif item.startswith("http://") or item.startswith("https://"): image = Image.open(requests.get(item, stream=True).raw) else: - image = Image.open(item) - image = image.convert("RGB") - images.append(self.image_transform(image)) - images = torch.stack(images, dim=0) - return self(images) + raise Exception("Unsupport input types: {} for {}".format(type(item), item)) + cur_num = img_tensors[-1].shape[0] + + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + if len(img_tensors) <= 0: + return None + + pixel_values = torch.stack(img_tensors, dim=0) + all_img_embeds = self(pixel_values) + + if len(uuids) == 0: + return [all_img_embeds[start:end] for start, end in valid_ids] + else: + for i in range(len(uuids)): + uid = uuids[i] + if not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + + return def load_model(self, weight_dir): import os diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 7a5463f3..d4ac7d23 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -465,6 +465,13 @@ def make_argument_parser() -> argparse.ArgumentParser: "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" ) parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") + parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") + parser.add_argument( + "--visual_nccl_port", + type=int, + default=29500, + help="the visual_nccl_port to build a distributed environment for Vit", + ) parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) @@ -521,7 +528,9 @@ def main(): logger.info(f"all start args:{args}") - can_use_ports = alloc_can_use_network_port(num=6 + args.tp * 2, used_nccl_port=args.nccl_port) + can_use_ports = alloc_can_use_network_port( + num=6 + args.tp + args.visual_dp, used_nccl_port=[args.nccl_port, args.visual_nccl_port] + ) router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6] model_rpc_ports = can_use_ports[6 : 6 + args.tp] visual_model_rpc_ports = can_use_ports[6 + args.tp :] diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index d6e2d5a9..0f4c7782 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -15,6 +15,9 @@ from PIL import Image import time import torch +import logging + +logging.basicConfig(level=logging.INFO) class VisualManager: @@ -38,7 +41,7 @@ def __init__( self.waiting_reqs = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp - self.world_size = args.tp + self.vit_world_size = args.visual_dp 
self.infer_batch_size = infer_batch_size self.trust_remote_code = args.trust_remote_code self.args = args @@ -47,20 +50,24 @@ def __init__( async def wait_to_model_ready(self): self.model_rpcs: List[VisualModelRpcClient] = [] - for rank_id in range(self.world_size): - rpc_model = await start_model_process(port=self.model_rpcs_ports[rank_id], world_size=self.world_size) + for rank_id in range(self.vit_world_size): + print(f"self.vit_world_size is {self.vit_world_size}") + rpc_model = await start_model_process( + port=self.model_rpcs_ports[rank_id], vit_world_size=self.vit_world_size + ) self.model_rpcs.append(rpc_model) init_model_ret = [] - for rank_id in range(self.world_size): # async init model process + for rank_id in range(self.vit_world_size): # async init model process kvargs = { "weight_dir": self.model_weightdir, "trust_remote_code": self.trust_remote_code, - "world_size": self.world_size, + "vit_world_size": self.vit_world_size, "client_port": self.client_port, "rank_id": rank_id, "data_type": self.args.data_type, "nccl_port": self.args.nccl_port, + "visual_nccl_port": self.args.visual_nccl_port, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) @@ -77,38 +84,18 @@ async def infer_imgs(self, uuids): if len(uuids) == 0: return # uuids -> PIL Images - images = [] - for uid in uuids: - image_data = read_shm(get_shm_name_data(uid)) - images.append(Image.open(BytesIO(image_data))) - # print(" + got pil image:", images[-1].size, images[-1].mode) tasks = [] - for tp_rank in range(self.world_size): - assigned_images = [images[i] for i in range(tp_rank, len(images), self.world_size)] - assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.world_size)] - if assigned_images: - task = asyncio.create_task(self.encode_and_store(tp_rank, assigned_images, assigned_uuids)) + for tp_rank in range(self.vit_world_size): + assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.vit_world_size)] + if assigned_uuids: + logging.info(f"tp {tp_rank} is processing {assigned_uuids}") + task = asyncio.create_task(self.model_rpcs[tp_rank].encode(assigned_uuids)) tasks.append(task) # rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)] await asyncio.gather(*tasks) - torch.cuda.synchronize() return - async def encode_and_store(self, tp_rank, assigned_images, assigned_uuids): - # aynsc vit-encode - img_embeds = [] - result = await self.model_rpcs[tp_rank].encode(assigned_images) - img_embeds.extend(obtain(result)) - print(f"cuda{tp_rank} is work on uuid {assigned_uuids}") - print("len of img_embeds is", len(img_embeds)) - # write img_embed to shm - for i, uid in enumerate(assigned_uuids): - if not self.cache_client.root.get_item_embed(uid): - cur_embed_bytes = tensor2bytes(img_embeds[i]) - create_shm(get_shm_name_embed(uid), cur_embed_bytes) - self.cache_client.root.set_item_embed(uid) - async def loop_for_fwd(self): while True: if len(self.waiting_reqs) == 0: diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 015f3a02..bed04872 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -22,30 +22,35 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist - world_size = kvargs["world_size"] + vit_world_size = kvargs["vit_world_size"] self.tp_rank = kvargs["rank_id"] client_port = kvargs["client_port"] data_type = 
kvargs["data_type"] weight_dir = kvargs["weight_dir"] model_kvargs = { "tp_rank": self.tp_rank, - "world_size": world_size, + "vit_world_size": vit_world_size, "weight_dir": weight_dir, "client_port": client_port, "data_type": data_type, } - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(kvargs["nccl_port"] + 1) - dist.init_process_group(backend="nccl", rank=self.tp_rank, world_size=world_size) + dist.init_process_group( + backend="nccl", + init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', + rank=self.tp_rank, + world_size=vit_world_size, + ) torch.cuda.set_device(self.tp_rank) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: self.model_type = model_cfg["model_type"] if self.model_type == "qwen": - self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() + self.model = QWenVisionTransformer(model_kvargs, **model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": - self.model = Qwen2VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = ( + Qwen2VisionTransformerPretrainedModel(model_kvargs, **model_cfg["vision_config"]).eval().bfloat16() + ) elif self.model_type == "llava": self.model = LlavaVisionModel(model_kvargs) elif self.model_type == "internlmxcomposer2": @@ -79,11 +84,11 @@ def exposed_encode(self, images): class VisualModelRpcClient: - def __init__(self, model_rpc, world_size, rpc_server_process=None): + def __init__(self, model_rpc, vit_world_size, rpc_server_process=None): self.model: VisualModelRpcServer = model_rpc - self.world_size = world_size + self.vit_world_size = vit_world_size self.rpc_server_process = rpc_server_process - self.use_rpc = self.world_size != 1 + self.use_rpc = self.vit_world_size != 1 if self.use_rpc: def async_wrap(f): @@ -112,8 +117,8 @@ async def init_model(self, kvargs): else: return - async def encode(self, images): - ans = self._encode(images) + async def encode(self, uuids): + ans = self._encode(uuids) if self.use_rpc: return await ans else: @@ -134,9 +139,9 @@ def _init_env(port): return -async def start_model_process(port, world_size): - if world_size == 1: - return VisualModelRpcClient(VisualModelRpcServer(), world_size) +async def start_model_process(port, vit_world_size): + if vit_world_size == 1: + return VisualModelRpcClient(VisualModelRpcServer(), vit_world_size) import multiprocessing proc = multiprocessing.Process(target=_init_env, args=(port,)) @@ -154,4 +159,4 @@ async def start_model_process(port, world_size): raise Exception("init rpc env error!") assert proc.is_alive() - return VisualModelRpcClient(con.root, world_size, rpc_server_process=proc) + return VisualModelRpcClient(con.root, vit_world_size, rpc_server_process=proc) From 9c01e576cfc5c8b15e7cfa3657863f2460a88040 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Sun, 29 Sep 2024 12:51:01 +0000 Subject: [PATCH 4/8] vit dp --- lightllm/common/basemodel/basemodel.py | 17 +++++++- lightllm/common/basemodel/cuda_graph.py | 42 +++++++++++++++++++ lightllm/common/basemodel/infer_struct.py | 6 +++ lightllm/server/api_server.py | 20 ++++++++- lightllm/server/router/manager.py | 3 ++ .../model_infer/mode_backend/base_backend.py | 3 ++ lightllm/server/visualserver/manager.py | 8 ++-- 7 files changed, 92 insertions(+), 7 deletions(-) create mode 100644 lightllm/common/basemodel/cuda_graph.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index b130edb2..97c678dd 100755 --- a/lightllm/common/basemodel/basemodel.py 
+++ b/lightllm/common/basemodel/basemodel.py @@ -14,6 +14,7 @@ from lightllm.common.build_utils import repair_config from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.triton_kernel.splitfuse_copy_kv_index_to_req import splitfuse_copy_kv_index_to_req +from lightllm.common.basemodel.cuda_graph import CudaGraph torch.backends.cudnn.enabled = True @@ -50,6 +51,10 @@ def __init__(self, kvargs): assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time" self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) self.data_type = kvargs.get("data_type", "float16") + graph_max_batch_size = kvargs.get("graph_max_batch_size", 16) + graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8196) + disable_cudagraph = kvargs.get("disable_cudagraph", False) + self.graph = None if disable_cudagraph else CudaGraph(graph_max_batch_size, graph_max_len_in_batch) self._init_datatype() self._init_config() @@ -285,7 +290,9 @@ def _decode( infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - alloc_mem = self.mem_manager.alloc_contiguous(batch_size) + # 在使用 cuda graph 特性的时候,必须保证每次推理的流程一致 + # 所以不再使用分配连续的mem带来的优化,保证推理流程的一致 + alloc_mem = None if self.graph is not None else self.mem_manager.alloc_contiguous(batch_size) if alloc_mem is not None: infer_state.mem_is_contiguous = True infer_state.mem_index = alloc_mem[0] @@ -304,7 +311,13 @@ def _decode( copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) infer_state.init_some_extra_state(self, input_ids) - predict_logics = self._token_forward(input_ids, infer_state) + if self.graph.can_run(batch_size, max_len_in_batch): + if self.graph.need_capture(batch_size): + predict_logics = self.graph.capture_decode(self._token_forward, input_ids, infer_state) + else: + predict_logics = self.graph.replay(input_ids, infer_state) + else: + predict_logics = self._token_forward(input_ids, infer_state) return predict_logics @torch.no_grad() diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py new file mode 100644 index 00000000..3baa800e --- /dev/null +++ b/lightllm/common/basemodel/cuda_graph.py @@ -0,0 +1,42 @@ +import os +import torch + + +class CudaGraph: + # CudaGraph forward pass for the decoding stage. 
+ + def __init__(self, max_batch_size=8, max_len_in_batch=8192): + self.graph = {} + self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + self.max_batch_size = max_batch_size + self.graph_max_len_in_batch = max_len_in_batch + + def can_run(self, batch_size, max_len_in_batch): + return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch + + def need_capture(self, batch_size): + return batch_size not in self.graph + + def capture_decode(self, decode_func, input_ids, infer_state): + graph_obj = torch.cuda.CUDAGraph() + batch_size = input_ids.shape[0] + infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.total_token_num = self.graph_max_len_in_batch * batch_size + # warmup + for _ in range(1): + torch.cuda.synchronize() + decode_func(input_ids, infer_state) + torch.cuda.synchronize() + with torch.cuda.graph(graph_obj, pool=self.mempool): + predict_logics = decode_func(input_ids, infer_state) + self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics) + graph_obj.replay() + return predict_logics + + def replay(self, input_ids, infer_state): + batch_size = input_ids.shape[0] + graph_obj, graph_input_ids, graph_infer_state, graph_predict_logics = self.graph[batch_size] + graph_input_ids.copy_(input_ids) + graph_infer_state.copy_for_cuda_graph(infer_state) + graph_obj.replay() + return graph_predict_logics diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 5d9f6393..fa1bf921 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -38,3 +38,9 @@ def __init__(self): def init_some_extra_state(self, model, input_ids: torch.Tensor): pass + + def copy_for_cuda_graph(self, new_infer_state): + for attr_name, attr_value in vars(new_infer_state).items(): + if isinstance(attr_value, torch.Tensor): + getattr(self, attr_name).copy_(attr_value) + return diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index d4ac7d23..67bfbf97 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -475,7 +475,21 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) - + parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage") + parser.add_argument( + "--graph_max_batch_size", + type=int, + default=16, + help="""Maximum batch size that can be captured by the cuda graph for decodign stage. + The default value is 8. It will turn into eagar mode if encounters a larger value.""", + ) + parser.add_argument( + "--graph_max_len_in_batch", + type=int, + default=8192, + help="""Maximum sequence length that can be captured by the cuda graph for decodign stage. + The default value is 8192. It will turn into eagar mode if encounters a larger value. 
""", + ) return parser @@ -499,6 +513,10 @@ def main(): assert args.max_req_total_len <= args.max_total_token_num assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache" + # splitfuse_mode 和 cuda_graph 不能同时开启 + if args.splitfuse_mode: + assert args.disable_cudagraph + # 这些模式不能同时设置。 assert [args.splitfuse_mode, args.beam_mode, args.diverse_mode, args.token_healing_mode].count(True) <= 1 # 部分模式目前还无法与dynamic_prompt_cache一起跑,to do。 diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 90e4aede..7edf2786 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -103,6 +103,9 @@ async def wait_to_model_ready(self): "eos_id": self.eos_id, "beam_mode": self.args.beam_mode, "diverse_mode": self.args.diverse_mode, + "graph_max_batch_size": self.args.graph_max_batch_size, + "graph_max_len_in_batch": self.args.graph_max_len_in_batch, + "disable_cudagraph": self.args.disable_cudagraph, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 7aa0a23f..8dadf6eb 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -99,6 +99,9 @@ def init_model(self, kvargs): "return_all_prompt_logics": self.return_all_prompt_logprobs, "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache, "data_type": kvargs.get("data_type", "float16"), + "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16), + "graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196), + "disable_cudagraph": kvargs.get("disable_cudagraph", False), } is_weight_only_quant = any("w6a16" in mode_ or "w8a16" in mode_ or "w4a16" in mode_ for mode_ in self.mode) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 0f4c7782..a5c13423 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -15,9 +15,10 @@ from PIL import Image import time import torch -import logging -logging.basicConfig(level=logging.INFO) +# import logging + +# logging.basicConfig(level=logging.INFO) class VisualManager: @@ -51,7 +52,6 @@ async def wait_to_model_ready(self): self.model_rpcs: List[VisualModelRpcClient] = [] for rank_id in range(self.vit_world_size): - print(f"self.vit_world_size is {self.vit_world_size}") rpc_model = await start_model_process( port=self.model_rpcs_ports[rank_id], vit_world_size=self.vit_world_size ) @@ -88,7 +88,7 @@ async def infer_imgs(self, uuids): for tp_rank in range(self.vit_world_size): assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.vit_world_size)] if assigned_uuids: - logging.info(f"tp {tp_rank} is processing {assigned_uuids}") + # logging.info(f"tp {tp_rank} is processing {assigned_uuids}") task = asyncio.create_task(self.model_rpcs[tp_rank].encode(assigned_uuids)) tasks.append(task) From 26c06ad4a90859f44f20f2481e68b9ad74519dd2 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Sun, 29 Sep 2024 20:52:34 +0800 Subject: [PATCH 5/8] Revert "vit dp" This reverts commit 9c01e576cfc5c8b15e7cfa3657863f2460a88040. 
--- lightllm/common/basemodel/basemodel.py | 17 +------- lightllm/common/basemodel/cuda_graph.py | 42 ------------------- lightllm/common/basemodel/infer_struct.py | 6 --- lightllm/server/api_server.py | 20 +-------- lightllm/server/router/manager.py | 3 -- .../model_infer/mode_backend/base_backend.py | 3 -- lightllm/server/visualserver/manager.py | 8 ++-- 7 files changed, 7 insertions(+), 92 deletions(-) delete mode 100644 lightllm/common/basemodel/cuda_graph.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 97c678dd..b130edb2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -14,7 +14,6 @@ from lightllm.common.build_utils import repair_config from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.triton_kernel.splitfuse_copy_kv_index_to_req import splitfuse_copy_kv_index_to_req -from lightllm.common.basemodel.cuda_graph import CudaGraph torch.backends.cudnn.enabled = True @@ -51,10 +50,6 @@ def __init__(self, kvargs): assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time" self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) self.data_type = kvargs.get("data_type", "float16") - graph_max_batch_size = kvargs.get("graph_max_batch_size", 16) - graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8196) - disable_cudagraph = kvargs.get("disable_cudagraph", False) - self.graph = None if disable_cudagraph else CudaGraph(graph_max_batch_size, graph_max_len_in_batch) self._init_datatype() self._init_config() @@ -290,9 +285,7 @@ def _decode( infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - # 在使用 cuda graph 特性的时候,必须保证每次推理的流程一致 - # 所以不再使用分配连续的mem带来的优化,保证推理流程的一致 - alloc_mem = None if self.graph is not None else self.mem_manager.alloc_contiguous(batch_size) + alloc_mem = self.mem_manager.alloc_contiguous(batch_size) if alloc_mem is not None: infer_state.mem_is_contiguous = True infer_state.mem_index = alloc_mem[0] @@ -311,13 +304,7 @@ def _decode( copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) infer_state.init_some_extra_state(self, input_ids) - if self.graph.can_run(batch_size, max_len_in_batch): - if self.graph.need_capture(batch_size): - predict_logics = self.graph.capture_decode(self._token_forward, input_ids, infer_state) - else: - predict_logics = self.graph.replay(input_ids, infer_state) - else: - predict_logics = self._token_forward(input_ids, infer_state) + predict_logics = self._token_forward(input_ids, infer_state) return predict_logics @torch.no_grad() diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py deleted file mode 100644 index 3baa800e..00000000 --- a/lightllm/common/basemodel/cuda_graph.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import torch - - -class CudaGraph: - # CudaGraph forward pass for the decoding stage. 
- - def __init__(self, max_batch_size=8, max_len_in_batch=8192): - self.graph = {} - self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None - self.max_batch_size = max_batch_size - self.graph_max_len_in_batch = max_len_in_batch - - def can_run(self, batch_size, max_len_in_batch): - return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch - - def need_capture(self, batch_size): - return batch_size not in self.graph - - def capture_decode(self, decode_func, input_ids, infer_state): - graph_obj = torch.cuda.CUDAGraph() - batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch - infer_state.total_token_num = self.graph_max_len_in_batch * batch_size - # warmup - for _ in range(1): - torch.cuda.synchronize() - decode_func(input_ids, infer_state) - torch.cuda.synchronize() - with torch.cuda.graph(graph_obj, pool=self.mempool): - predict_logics = decode_func(input_ids, infer_state) - self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics) - graph_obj.replay() - return predict_logics - - def replay(self, input_ids, infer_state): - batch_size = input_ids.shape[0] - graph_obj, graph_input_ids, graph_infer_state, graph_predict_logics = self.graph[batch_size] - graph_input_ids.copy_(input_ids) - graph_infer_state.copy_for_cuda_graph(infer_state) - graph_obj.replay() - return graph_predict_logics diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index fa1bf921..5d9f6393 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -38,9 +38,3 @@ def __init__(self): def init_some_extra_state(self, model, input_ids: torch.Tensor): pass - - def copy_for_cuda_graph(self, new_infer_state): - for attr_name, attr_value in vars(new_infer_state).items(): - if isinstance(attr_value, torch.Tensor): - getattr(self, attr_name).copy_(attr_value) - return diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 67bfbf97..d4ac7d23 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -475,21 +475,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) - parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage") - parser.add_argument( - "--graph_max_batch_size", - type=int, - default=16, - help="""Maximum batch size that can be captured by the cuda graph for decodign stage. - The default value is 8. It will turn into eagar mode if encounters a larger value.""", - ) - parser.add_argument( - "--graph_max_len_in_batch", - type=int, - default=8192, - help="""Maximum sequence length that can be captured by the cuda graph for decodign stage. - The default value is 8192. It will turn into eagar mode if encounters a larger value. 
""", - ) + return parser @@ -513,10 +499,6 @@ def main(): assert args.max_req_total_len <= args.max_total_token_num assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache" - # splitfuse_mode 和 cuda_graph 不能同时开启 - if args.splitfuse_mode: - assert args.disable_cudagraph - # 这些模式不能同时设置。 assert [args.splitfuse_mode, args.beam_mode, args.diverse_mode, args.token_healing_mode].count(True) <= 1 # 部分模式目前还无法与dynamic_prompt_cache一起跑,to do。 diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 7edf2786..90e4aede 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -103,9 +103,6 @@ async def wait_to_model_ready(self): "eos_id": self.eos_id, "beam_mode": self.args.beam_mode, "diverse_mode": self.args.diverse_mode, - "graph_max_batch_size": self.args.graph_max_batch_size, - "graph_max_len_in_batch": self.args.graph_max_len_in_batch, - "disable_cudagraph": self.args.disable_cudagraph, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8dadf6eb..7aa0a23f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -99,9 +99,6 @@ def init_model(self, kvargs): "return_all_prompt_logics": self.return_all_prompt_logprobs, "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache, "data_type": kvargs.get("data_type", "float16"), - "graph_max_batch_size": kvargs.get("graph_max_batch_size", 16), - "graph_max_len_in_batch": kvargs.get("graph_max_len_in_batch", 8196), - "disable_cudagraph": kvargs.get("disable_cudagraph", False), } is_weight_only_quant = any("w6a16" in mode_ or "w8a16" in mode_ or "w4a16" in mode_ for mode_ in self.mode) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a5c13423..0f4c7782 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -15,10 +15,9 @@ from PIL import Image import time import torch +import logging -# import logging - -# logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.INFO) class VisualManager: @@ -52,6 +51,7 @@ async def wait_to_model_ready(self): self.model_rpcs: List[VisualModelRpcClient] = [] for rank_id in range(self.vit_world_size): + print(f"self.vit_world_size is {self.vit_world_size}") rpc_model = await start_model_process( port=self.model_rpcs_ports[rank_id], vit_world_size=self.vit_world_size ) @@ -88,7 +88,7 @@ async def infer_imgs(self, uuids): for tp_rank in range(self.vit_world_size): assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.vit_world_size)] if assigned_uuids: - # logging.info(f"tp {tp_rank} is processing {assigned_uuids}") + logging.info(f"tp {tp_rank} is processing {assigned_uuids}") task = asyncio.create_task(self.model_rpcs[tp_rank].encode(assigned_uuids)) tasks.append(task) From b33bc451dbdcab44cb12274fb3e1841ec08106b5 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Sun, 29 Sep 2024 13:07:40 +0000 Subject: [PATCH 6/8] visual dp --- lightllm/server/visualserver/manager.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 0f4c7782..97c322c0 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py 
@@ -15,9 +15,6 @@ from PIL import Image import time import torch -import logging - -logging.basicConfig(level=logging.INFO) class VisualManager: @@ -51,7 +48,6 @@ async def wait_to_model_ready(self): self.model_rpcs: List[VisualModelRpcClient] = [] for rank_id in range(self.vit_world_size): - print(f"self.vit_world_size is {self.vit_world_size}") rpc_model = await start_model_process( port=self.model_rpcs_ports[rank_id], vit_world_size=self.vit_world_size ) @@ -88,7 +84,6 @@ async def infer_imgs(self, uuids): for tp_rank in range(self.vit_world_size): assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.vit_world_size)] if assigned_uuids: - logging.info(f"tp {tp_rank} is processing {assigned_uuids}") task = asyncio.create_task(self.model_rpcs[tp_rank].encode(assigned_uuids)) tasks.append(task) From c1adfe8372edae9da9b279f171b15e0466deacfd Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 8 Oct 2024 04:36:34 +0000 Subject: [PATCH 7/8] add visual_dp --- README.md | 2 +- lightllm/server/api_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e0e3609d..da76828d 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram > InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'. -> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code'. +> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code', and use 'pip install git+https://github.com/huggingface/transformers' to upgrade to the latest version. > Stablelm needs to set the parameter '--trust_remote_code'. diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 4ba2a522..67bfbf97 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -512,7 +512,7 @@ def main(): assert args.max_req_input_len < args.max_req_total_len assert args.max_req_total_len <= args.max_total_token_num assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache" - + # splitfuse_mode 和 cuda_graph 不能同时开启 if args.splitfuse_mode: assert args.disable_cudagraph From 14076bf2ed42a9286d38043fc095ba596d5f1043 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 18 Oct 2024 10:21:43 +0000 Subject: [PATCH 8/8] vit_dp_and_tp --- .../internlm_xcomposer/internlm_visual.py | 3 +- lightllm/models/internvl/internvl_visual.py | 13 +++-- lightllm/models/llava/llava_visual.py | 5 +- lightllm/models/qwen2_vl/qwen2_visual.py | 10 ++-- lightllm/models/qwen_vl/qwen_visual.py | 2 + lightllm/server/api_server.py | 34 ++++++++--- lightllm/server/visualserver/manager.py | 56 +++++++++++-------- .../visualserver/model_infer/model_rpc.py | 46 +++++++++------ 8 files changed, 109 insertions(+), 60 deletions(-) diff --git a/lightllm/models/internlm_xcomposer/internlm_visual.py b/lightllm/models/internlm_xcomposer/internlm_visual.py index dec1bb62..8753970a 100644 --- a/lightllm/models/internlm_xcomposer/internlm_visual.py +++ b/lightllm/models/internlm_xcomposer/internlm_visual.py @@ -20,6 +20,7 @@ def __init__(self, kvargs): self.world_size_ = kvargs["vit_world_size"] self.client_port = kvargs["client_port"] self.cache_client = rpyc.connect("localhost", self.client_port) + self.device = torch.device(f'cuda:{self.visual_gpu}') pass def load_projector_update(self, config, weight_dir): @@ -121,7 +122,7 @@ def cuda(self): self.vision_tower = self.vision_tower.cuda() for i in 
range(len(self.projector_weights)): self.projector_weights[i] = self.projector_weights[i].cuda() - self.device = torch.device(f"cuda:{self.tp_rank_}") + torch.cuda.set_device(self.device) return self # batch images infer diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index 27768c60..cdd2c3ae 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ b/lightllm/models/internvl/internvl_visual.py @@ -21,15 +21,16 @@ class InternVLVisionModel: def __init__(self, kvargs): - self.tp_rank_ = kvargs["tp_rank"] - self.world_size_ = kvargs["vit_world_size"] + self.tp_rank_id = kvargs["tp_rank_id"] + self.vit_tp = kvargs["vit_tp"] self.client_port = kvargs["client_port"] self.cache_client = rpyc.connect("localhost", self.client_port) + self.visual_gpu = kvargs["visual_gpu"] + self.device = torch.device(f'cuda:{self.visual_gpu}') pass def load_model(self, weight_dir): assert torch.cuda.is_available() - self.device = torch.device(f"cuda:{self.tp_rank_}") self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 self.config = json.load(open(os.path.join(weight_dir, "config.json"))) self.model = AutoModel.from_pretrained( @@ -51,7 +52,7 @@ def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]): # load images to batch tensor for i, url in enumerate(image_items): - if self.world_size_ != 1: + if self.vit_tp != 1: url = obtain(url) if isinstance(url, Image.Image): t = load_image(url, max_num=6) @@ -75,9 +76,11 @@ def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]): if len(img_tensors) <= 0: return None # (b, 3, 224, 224) + torch.cuda.set_device(self.device) imgs = torch.cat(img_tensors, dim=0) - pixel_values = imgs.to(self.device, dtype=self.dtype) + pixel_values = imgs.to(device=self.device, dtype=self.dtype) all_img_embeds = self.model.extract_feature(pixel_values) + current_device = torch.cuda.current_device() if len(uuids) == 0: return [all_img_embeds[start:end] for start, end in valid_ids] diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 34cac039..d14a5afb 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -21,6 +21,8 @@ def __init__(self, kvargs): self.world_size_ = kvargs["vit_world_size"] self.client_port = kvargs["client_port"] self.cache_client = rpyc.connect("localhost", self.client_port) + self.visual_gpu = kvargs["visual_gpu"] + self.device = torch.device(f'cuda:{self.visual_gpu}') pass def load_model(self, weight_dir): @@ -102,7 +104,8 @@ def cuda(self): self.vision_tower = self.vision_tower.cuda() for k, v in self.projector_weights.items(): self.projector_weights[k] = v.cuda() - self.device = torch.device(f"cuda:{self.tp_rank_}") + self.device = torch.device(self.device) + torch.cuda.set_device(self.device) return self # batch images infer diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index d5ced3b5..a9878ca4 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -313,6 +313,8 @@ def __init__( self.world_size_ = kvargs["vit_world_size"] self.client_port = kvargs["client_port"] self.cache_client = rpyc.connect("localhost", self.client_port) + self.visual_gpu = kvargs["visual_gpu"] + self.device = torch.device(f'cuda:{self.visual_gpu}') super().__init__() self.depth = depth self.embed_dim = embed_dim @@ -388,11 +390,11 @@ def rot_pos_emb(self, grid_thw): def 
forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.to(
             dtype=self.get_dtype(),
-            device=torch.device("cuda"),
+            device=self.device,
         )
         grid_thw = grid_thw.to(
             dtype=torch.int32,
-            device=torch.device("cuda"),
+            device=self.device,
         )
 
         hidden_states = self.patch_embed(hidden_states)
@@ -445,7 +447,7 @@ def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]):
             if isinstance(url, Image.Image):
                 t = get_image(url)
                 image_inputs = self.processor.preprocess(images=t, return_tensors="pt")
-                pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
+                pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device=self.device)
                 image_grid_thw = image_inputs["image_grid_thw"]
                 img_tensors.append(pixel_values)
                 img_grids.append(image_grid_thw)
@@ -457,7 +459,7 @@ def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]):
                 image_data = Image.open(BytesIO(image_data))
                 image_data = get_image(image_data)
                 image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-                pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
+                pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device=self.device)
                 image_grid_thw = image_inputs["image_grid_thw"]
                 img_tensors.append(pixel_values)
                 img_grids.append(image_grid_thw)
diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py
index d0b75428..050b9f47 100644
--- a/lightllm/models/qwen_vl/qwen_visual.py
+++ b/lightllm/models/qwen_vl/qwen_visual.py
@@ -350,6 +350,8 @@ def __init__(
         self.world_size_ = kvargs["vit_world_size"]
         self.client_port = kvargs["client_port"]
         self.cache_client = rpyc.connect("localhost", self.client_port)
+        self.visual_gpu = kvargs["visual_gpu"]
+        self.device = torch.device(f'cuda:{self.visual_gpu}')
         super().__init__()
         image_height, image_width = self.image_size = (image_size, image_size)
         patch_height, patch_width = self.patch_size = (patch_size, patch_size)
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index 67bfbf97..baf80e5b 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -465,12 +465,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value"
     )
     parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics")
+    parser.add_argument("--visual_gpu_ids", type=str, default="0", help="Comma-separated GPU IDs to use, e.g., '0,1,2'")
+    parser.add_argument("--visual_tp", type=int, default=1, help="number of tensor parallel instances for ViT")
     parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT")
     parser.add_argument(
-        "--visual_nccl_port",
-        type=int,
-        default=29500,
-        help="the visual_nccl_port to build a distributed environment for Vit",
+        "--visual_nccl_ports",
+        type=str,
+        default="29500",
+        help="Comma-separated list of NCCL ports to build a distributed environment for ViT.",
     )
     parser.add_argument(
         "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway"
@@ -527,6 +529,20 @@ def main():
     # Some modes cannot yet work with the advanced dynamic scheduling algorithm, to do.
    if args.beam_mode or args.diverse_mode:
         assert args.router_token_ratio == 0.0
+
+    # Check that enough GPUs are specified for the ViT workers
+    visual_gpu_ids = [int(gpu_id) for gpu_id in args.visual_gpu_ids.split(",")]
+    total_required_gpus = args.visual_dp * args.visual_tp
+    if len(visual_gpu_ids) < total_required_gpus:
+        raise ValueError(f"Not enough GPUs specified. You need at least {total_required_gpus} GPUs, but got {len(visual_gpu_ids)}. Use --visual_gpu_ids to set them, e.g. --visual_gpu_ids 0,2,3,5")
+    else:
+        args.visual_gpu_ids = visual_gpu_ids[:total_required_gpus]
+
+    visual_nccl_port_ids = [int(nccl_port_id) for nccl_port_id in args.visual_nccl_ports.split(",")]
+    if len(visual_nccl_port_ids) != args.visual_dp:
+        raise ValueError(f"The number of ports ({len(visual_nccl_port_ids)}) does not match visual_dp ({args.visual_dp}).")
+
+    args.visual_nccl_port = visual_nccl_port_ids
 
     if not args.splitfuse_mode:
         # normal mode
@@ -547,11 +563,15 @@ def main():
     logger.info(f"all start args:{args}")
 
     can_use_ports = alloc_can_use_network_port(
-        num=6 + args.tp + args.visual_dp, used_nccl_port=[args.nccl_port, args.visual_nccl_port]
+        num=6 + args.tp + args.visual_dp * args.visual_tp, used_nccl_port=[args.nccl_port, args.visual_nccl_port]
     )
     router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
     model_rpc_ports = can_use_ports[6 : 6 + args.tp]
-    visual_model_rpc_ports = can_use_ports[6 + args.tp :]
+
+    visual_model_tp_ports = []
+    for dp_index in range(args.visual_dp):
+        tp_ports_for_dp = can_use_ports[6 + args.tp + dp_index * args.visual_tp : 6 + args.tp + (dp_index + 1) * args.visual_tp]
+        visual_model_tp_ports.append(tp_ports_for_dp)
 
     if args.enable_multimodal:
         start_submodule_processes(
@@ -600,7 +620,7 @@ def main():
             start_visual_process,
         ],
         start_args=[
-            (args, router_port, visual_port, cache_port, visual_model_rpc_ports),
+            (args, router_port, visual_port, cache_port, visual_model_tp_ports),
         ],
     )
 
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index 97c322c0..ccf45f0c 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -24,7 +24,7 @@ def __init__(
         router_port,
         visual_port,
         client_port,
-        model_rpc_ports,
+        visual_model_rpc_ports,
         infer_batch_size=4,
     ):
         context = zmq.asyncio.Context(2)
@@ -38,34 +38,42 @@ def __init__(
         self.waiting_reqs = []
         self.model_weightdir = args.model_dir
         self.tp_world_size = args.tp
-        self.vit_world_size = args.visual_dp
+        self.vit_dp = args.visual_dp
+        self.vit_tp = args.visual_tp
         self.infer_batch_size = infer_batch_size
         self.trust_remote_code = args.trust_remote_code
         self.args = args
-        self.model_rpcs_ports = model_rpc_ports
+        self.visual_model_rpcs_ports = visual_model_rpc_ports
 
     async def wait_to_model_ready(self):
-        self.model_rpcs: List[VisualModelRpcClient] = []
-        for rank_id in range(self.vit_world_size):
-            rpc_model = await start_model_process(
-                port=self.model_rpcs_ports[rank_id], vit_world_size=self.vit_world_size
-            )
-            self.model_rpcs.append(rpc_model)
+        self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)]
+        for dp_rank_id in range(self.vit_dp):
+            tp_ports_each_dp = self.visual_model_rpcs_ports[dp_rank_id]
+            for tp_rank_id in range(self.vit_tp):
+                rpc_model = await start_model_process(
+                    port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp
+                )
+                self.model_rpcs[dp_rank_id].append(rpc_model)
+
         init_model_ret = []
-        for rank_id in range(self.vit_world_size):  # async init model process
-            kvargs = {
-                "weight_dir": self.model_weightdir,
-                "trust_remote_code": self.trust_remote_code,
-                "vit_world_size": self.vit_world_size,
-                "client_port": self.client_port,
-                "rank_id": rank_id,
-                "data_type": self.args.data_type,
-                "nccl_port": self.args.nccl_port,
-                "visual_nccl_port": self.args.visual_nccl_port,
-            }
-            init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))
+        for dp_rank_id in range(self.vit_dp):  # async init model process
+            for tp_rank_id in range(self.vit_tp):
+                kvargs = {
+                    "weight_dir": self.model_weightdir,
+                    "trust_remote_code": self.trust_remote_code,
+                    "vit_dp": self.vit_dp,
+                    "vit_tp": self.vit_tp,
+                    "client_port": self.client_port,
+                    "tp_rank_id": tp_rank_id,
+                    "dp_rank_id": dp_rank_id,
+                    "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id,
+                    "data_type": self.args.data_type,
+                    "visual_nccl_port": self.args.visual_nccl_port[dp_rank_id],
+                    "visual_gpu_ids": self.args.visual_gpu_ids,
+                }
+                init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
         await asyncio.gather(*init_model_ret)
         return
 
@@ -81,10 +89,10 @@ async def infer_imgs(self, uuids):
             return
         # uuids -> PIL Images
         tasks = []
-        for tp_rank in range(self.vit_world_size):
-            assigned_uuids = [uuids[i] for i in range(tp_rank, len(uuids), self.vit_world_size)]
+        for vit_dp_rank in range(self.vit_dp):
+            assigned_uuids = [uuids[i] for i in range(vit_dp_rank, len(uuids), self.vit_dp)]
             if assigned_uuids:
-                task = asyncio.create_task(self.model_rpcs[tp_rank].encode(assigned_uuids))
+                task = asyncio.create_task(self.model_rpcs[vit_dp_rank][0].encode(assigned_uuids))
                 tasks.append(task)
 
         # rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)]
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index bed04872..6ebd1210 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -22,25 +22,35 @@ def exposed_init_model(self, kvargs):
         import torch
         import torch.distributed as dist
 
-        vit_world_size = kvargs["vit_world_size"]
-        self.tp_rank = kvargs["rank_id"]
+        self.vit_dp = kvargs["vit_dp"]
+        self.vit_tp = kvargs["vit_tp"]
+        self.dp_rank_id = kvargs["dp_rank_id"]
+        self.tp_rank_id = kvargs["tp_rank_id"]
         client_port = kvargs["client_port"]
         data_type = kvargs["data_type"]
         weight_dir = kvargs["weight_dir"]
+        visual_gpu_ids = kvargs["visual_gpu_ids"]
+        visual_nccl_port = kvargs["visual_nccl_port"]
+        self.vit_rank_id = kvargs["vit_rank_id"]
+
         model_kvargs = {
-            "tp_rank": self.tp_rank,
-            "vit_world_size": vit_world_size,
+            "tp_rank_id": self.tp_rank_id,
+            "vit_tp": self.vit_tp,
             "weight_dir": weight_dir,
             "client_port": client_port,
             "data_type": data_type,
+            "vit_rank_id": self.vit_rank_id,
+            "visual_gpu": visual_gpu_ids[self.vit_rank_id],
         }
-        dist.init_process_group(
-            backend="nccl",
-            init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
-            rank=self.tp_rank,
-            world_size=vit_world_size,
-        )
-        torch.cuda.set_device(self.tp_rank)
+        if self.vit_tp != 1:
+            dist.init_process_group(
+                backend="nccl",
+                init_method=f'tcp://127.0.0.1:{visual_nccl_port}',  # NCCL is only needed for TP; DP groups run independently, each with its own visual_nccl_port from api_server
+ rank=self.tp_rank_id, + world_size=self.vit_tp, + ) + print(f"self.tp_rank_id:{self.tp_rank_id}, self.vit_rank_id:{self.vit_rank_id},visual_gpu_ids[self.vit_rank_id] is {visual_gpu_ids[self.vit_rank_id]} ") + torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: @@ -84,11 +94,11 @@ def exposed_encode(self, images): class VisualModelRpcClient: - def __init__(self, model_rpc, vit_world_size, rpc_server_process=None): + def __init__(self, model_rpc, vit_tp, rpc_server_process=None): self.model: VisualModelRpcServer = model_rpc - self.vit_world_size = vit_world_size + self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process - self.use_rpc = self.vit_world_size != 1 + self.use_rpc = self.vit_tp != 1 if self.use_rpc: def async_wrap(f): @@ -139,9 +149,9 @@ def _init_env(port): return -async def start_model_process(port, vit_world_size): - if vit_world_size == 1: - return VisualModelRpcClient(VisualModelRpcServer(), vit_world_size) +async def start_model_process(port, vit_tp): + if vit_tp == 1: + return VisualModelRpcClient(VisualModelRpcServer(), vit_tp) import multiprocessing proc = multiprocessing.Process(target=_init_env, args=(port,)) @@ -159,4 +169,4 @@ async def start_model_process(port, vit_world_size): raise Exception("init rpc env error!") assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_world_size, rpc_server_process=proc) + return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc)
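
Reviewer note (not part of the patches above): the net effect of PATCH 8 is that the visual server runs visual_dp data-parallel groups of visual_tp tensor-parallel ranks, maps the flattened rank vit_rank_id = dp_rank_id * visual_tp + tp_rank_id onto one entry of --visual_gpu_ids, gives every DP group its own NCCL port, and spreads incoming image uuids over the DP groups round-robin (only tp_rank_id == 0 of each group receives the encode RPC; the group only joins an NCCL process group when visual_tp > 1). The snippet below is a minimal, self-contained sketch of that bookkeeping for illustration only; the helper names plan_vit_ranks and split_uuids are made up and do not exist in the repository.

from dataclasses import dataclass
from typing import List


@dataclass
class VitRank:
    dp_rank_id: int
    tp_rank_id: int
    vit_rank_id: int  # flattened index: dp_rank_id * visual_tp + tp_rank_id
    gpu_id: int       # entry of --visual_gpu_ids assigned to this rank
    nccl_port: int    # one port per DP group, shared by its TP ranks


def plan_vit_ranks(visual_dp: int, visual_tp: int, gpu_ids: List[int], nccl_ports: List[int]) -> List[VitRank]:
    # Mirrors the checks done in api_server.main(): enough GPUs for dp*tp, one NCCL port per DP group.
    assert len(gpu_ids) >= visual_dp * visual_tp, "need at least visual_dp * visual_tp GPU ids"
    assert len(nccl_ports) == visual_dp, "need exactly one NCCL port per DP group"
    ranks = []
    for dp_rank_id in range(visual_dp):
        for tp_rank_id in range(visual_tp):
            vit_rank_id = dp_rank_id * visual_tp + tp_rank_id
            ranks.append(VitRank(dp_rank_id, tp_rank_id, vit_rank_id, gpu_ids[vit_rank_id], nccl_ports[dp_rank_id]))
    return ranks


def split_uuids(uuids: List[int], visual_dp: int) -> List[List[int]]:
    # Round-robin assignment used by VisualManager.infer_imgs: uuid at index i goes to DP group i % visual_dp.
    return [[uuids[i] for i in range(dp, len(uuids), visual_dp)] for dp in range(visual_dp)]


if __name__ == "__main__":
    for r in plan_vit_ranks(visual_dp=2, visual_tp=2, gpu_ids=[0, 1, 2, 3], nccl_ports=[29500, 29501]):
        print(r)
    print(split_uuids([101, 102, 103, 104, 105], visual_dp=2))

For example, with --visual_dp 2 --visual_tp 2 --visual_gpu_ids 0,1,2,3 --visual_nccl_ports 29500,29501 this plan places ranks (0,0)/(0,1) on GPUs 0/1 sharing port 29500 and ranks (1,0)/(1,1) on GPUs 2/3 sharing port 29501, and the uuids [101..105] split into [101, 103, 105] for DP group 0 and [102, 104] for DP group 1.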